Axis names are now tracked via an effect

This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
This commit is contained in:
Sergei Lebedev 2024-03-04 05:41:29 -08:00 committed by jax authors
parent 2dd5e9e180
commit 5283d4b4a5
8 changed files with 145 additions and 58 deletions

View File

@ -569,6 +569,10 @@ def xla_computation(fun: Callable,
stack.enter_context(core.extend_axis_env(axis_name, size, None))
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
if axis_env:
jaxpr = core.remove_named_axis_effects(
jaxpr, {axis_name for axis_name, _ in axis_env}
)
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
ordered_effects = list(
effects.ordered_effects.filter_in(jaxpr.effects))

View File

@ -15,8 +15,9 @@ from __future__ import annotations
import collections # noqa: F401
from collections import Counter, defaultdict, deque, namedtuple
from collections.abc import (Generator, Hashable, Iterable, Iterator, Sequence,
MutableSet, MutableMapping)
from collections.abc import (Collection, Generator, Hashable, Iterable,
Iterator, Set, Sequence, MutableSet,
MutableMapping)
from contextlib import contextmanager
from dataclasses import dataclass
import functools
@ -2637,6 +2638,34 @@ def axis_frame(axis_name: AxisName, main_trace: MainTrace | None = None
f'by pmap) are available to collective operations: {named_axes}')
@dataclass(frozen=True)
class NamedAxisEffect(effects.Effect):
"""A side-effect introducing a new named axis into the current scope."""
name: AxisName
effects.control_flow_allowed_effects.add_type(NamedAxisEffect)
effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect)
effects.lowerable_effects.add_type(NamedAxisEffect)
effects.remat_allowed_effects.add_type(NamedAxisEffect)
def filter_named_axis_effects(
effects: Effects, names: Collection[AxisName]
) -> Effects:
return {e for e in effects
if not isinstance(e, NamedAxisEffect) or e.name not in names}
def remove_named_axis_effects(
jaxpr: Jaxpr, names: Collection[AxisName]
) -> Jaxpr:
if not names or not jaxpr.effects:
return jaxpr
return jaxpr.replace(effects=filter_named_axis_effects(jaxpr.effects, names))
ParamDict = dict[str, Any]
AxisSubst = Callable[[AxisName], tuple[AxisName, ...]]
@ -2676,6 +2705,15 @@ class DuplicateAxisNameError(Exception):
self.var = var
self.eqn = None
def subst_axis_names_effects(effects: Set[Effect], subst: AxisSubst) -> Set[Effect]:
new_effects = set[Effect]()
for e in effects:
if isinstance(e, NamedAxisEffect):
new_effects.update(map(NamedAxisEffect, subst(e.name)))
else:
new_effects.add(e)
return new_effects
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var:
# Var identity is load-bearing, so we can't have duplicates!
if isinstance(v, DropVar): return v
@ -2699,7 +2737,8 @@ def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var
e.eqn = eqn
raise
params = subst_axis_names(eqn.primitive, eqn.params, subst)
return eqn.replace(invars=invars, outvars=outvars, params=params)
effects = subst_axis_names_effects(eqn.effects, subst)
return eqn.replace(invars=invars, outvars=outvars, params=params, effects=effects)
def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
consts = None
@ -2711,16 +2750,14 @@ def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr]
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] # type: ignore[union-attr]
outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
effects = subst_axis_names_effects(jaxpr.effects, subst)
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, effects)
if consts is not None:
return ClosedJaxpr(new_jaxpr, consts)
return new_jaxpr
@weakref_lru_cache
def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr):
subst = NameGatheringSubst()
do_subst_axis_names_jaxpr(jaxpr, subst)
return frozenset(subst.axis_names)
return {e.name for e in jaxpr.effects if isinstance(e, NamedAxisEffect)}
def subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
if isinstance(subst, NameGatheringSubst): # This is a common case, so we optimize it!
@ -2924,6 +2961,9 @@ def _check_jaxpr(
raise JaxprTypeError(
"Invalid `JaxprInputEffect`: must be present in jaxpr. "
f"{jaxpr_effect} is not in {jaxpr.effects}.")
elif isinstance(eff, NamedAxisEffect):
# It is valid for a primitive to discharge the named axis effect.
continue
elif eff not in jaxpr.effects:
raise JaxprTypeError("Equation effect not present in jaxpr effects. "
f"Equation effect: {eff}. "
@ -3077,7 +3117,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
if out_axis is not None else aval
for aval, out_axis in zip(mapped_out_avals, out_axes)]
return out_avals, call_jaxpr.effects
return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})
# ------------------- Jaxpr printed representation -------------------

View File

@ -2570,8 +2570,9 @@ def build_mlir_module_helper(
platforms: Sequence[str],
backend_or_name: str, axis_context: AxisContext) -> ir.Module:
"""Helper to generate pmap-style XLA computations for custom partitioners."""
if closed_jaxpr.effects:
raise NotImplementedError
unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}')
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
backend_or_name=backend_or_name, ordered_effects=[],
name_stack=source_info_util.NameStack(),

View File

@ -380,10 +380,10 @@ class JaxprTrace(Trace['JaxprTracer']):
for ax, a in zip(staged_out_axes, out_avals_mapped)]
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
for a in out_avals]
effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']})
src_info = source_info_util.current()
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type]
out_tracers, primitive, staged_params,
jaxpr.effects,
source_info_util.current())
out_tracers, primitive, staged_params, effs, src_info)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -2099,8 +2099,9 @@ class DynamicJaxprTrace(core.Trace):
update_params = call_param_updaters.get(map_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers), len(consts))
effs = core.filter_named_axis_effects(jaxpr.effects, {axis_name})
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
new_params, jaxpr.effects, source_info)
new_params, effs, source_info)
self.frame.add_eqn(eqn)
return out_tracers

View File

@ -766,6 +766,7 @@ def lower_parallel_callable(
axis_env = sharding_impls.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
replicated_args = [axis is None for axis in in_axes]
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
@ -1258,8 +1259,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):
def _pmap_dce_rule(used_outputs, eqn):
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
with maybe_extend_axis_env(eqn.params['axis_name'],
eqn.params['global_axis_size'], None):
axis_name = eqn.params["axis_name"]
with maybe_extend_axis_env(axis_name, eqn.params["global_axis_size"], None):
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
@ -1270,10 +1271,11 @@ def _pmap_dce_rule(used_outputs, eqn):
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
return used_inputs, None
else:
effs = core.filter_named_axis_effects(new_jaxpr.effects, {axis_name})
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
eqn.primitive, new_params, effs, eqn.source_info)
return used_inputs, new_eqn
@ -2238,6 +2240,7 @@ def lower_mesh_computation(
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
num_replicas = mesh.devices.size
num_partitions = 1
jaxpr = core.remove_named_axis_effects(jaxpr, mesh.axis_names)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module_name = f"{api_name}_{fun_name}"
with core.extend_axis_env_nd(mesh.shape.items()):

View File

@ -23,15 +23,12 @@ import itertools
import math
import string
import numpy as np
from jax import tree_util
from jax._src import core
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src import util
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
from jax._src.core import AxisName, ShapedArray, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -41,8 +38,9 @@ from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.util import (
unzip2, canonicalize_axis, safe_map, safe_zip, moveaxis)
from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip,
unzip2)
import numpy as np
unsafe_map, map = map, safe_map # type: ignore
@ -709,21 +707,23 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
assert all(isinstance(axis, int) for axis in axes)
return [pos_reducer(arg, axes) for arg in args]
def _allreduce_abstract_eval(*args, axes, axis_index_groups):
def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
# TODO(frostig,mattjj,jekbradbury): maybe check aval names here
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
named_shapes = [arg.named_shape for arg in args]
if axis_index_groups is None:
named_axes = {axis for axis in axes if not isinstance(axis, int)}
if axis_index_groups is None:
named_shapes = [{name: size for name, size in arg.named_shape.items()
if name not in named_axes} for arg in args]
else:
if len(pos_axes) != 0:
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
arg.dtype, named_shape=named_shape)
for arg, named_shape in zip(args, named_shapes)]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
@ -804,7 +804,7 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
psum_p = core.AxisPrimitive('psum')
psum_p.multiple_results = True
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
psum_p.def_abstract_eval(_allreduce_abstract_eval)
psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering(
psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
ad.deflinear2(psum_p, _psum_transpose_rule)
@ -840,7 +840,7 @@ def psum_bind(*args, axes, axis_index_groups):
pmax_p = core.AxisPrimitive('pmax')
pmax_p.multiple_results = True
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering(
pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
@ -852,7 +852,7 @@ core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes'
pmin_p = core.AxisPrimitive('pmin')
pmin_p.multiple_results = True
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
mlir.register_lowering(
pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
@ -1060,17 +1060,25 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
new_d -= 1 # We've removed 0th dimension, so new_d needs to be adjusted
return x, new_d
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups):
def _all_to_all_effectful_abstract_eval(
x, axis_name, split_axis, concat_axis, axis_index_groups
):
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
input_aval = raise_to_shaped(x)
shape = list(input_aval.shape)
axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0])
assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
shape[split_axis] //= axis_size
shape[concat_axis] *= axis_size
return input_aval.update(shape=tuple(shape), weak_type=False)
out_aval = input_aval.update(shape=tuple(shape), weak_type=False)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
all_to_all_p = core.AxisPrimitive('all_to_all')
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval)
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
@ -1204,7 +1212,10 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
replica_groups=_replica_groups_hlo(replica_groups),
**other_args).results
def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
def _all_gather_effectful_abstract_eval(
x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
):
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
x_aval = raise_to_shaped(x)
@ -1215,7 +1226,10 @@ def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_
new_shape.insert(all_gather_dimension, axis_size)
new_named_shape = {name: size for name, size in x_aval.named_shape.items()
if name not in axis_name}
return x_aval.update(shape=new_shape, named_shape=new_named_shape)
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
return (psum_scatter(cts, axis_name=axis_name,
@ -1264,7 +1278,7 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
return y, batching.not_mapped
all_gather_p = core.AxisPrimitive('all_gather')
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval)
all_gather_p.def_impl(_all_gather_impl)
mlir.register_lowering(all_gather_p, _all_gather_lowering)
for p in ("cuda", "rocm", "tpu"):
@ -1327,9 +1341,9 @@ def _reduce_scatter_lowering(
return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]
def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
axis_index_groups, axis_size, tiled):
def _reduce_scatter_effectful_abstract_eval(
x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
):
if not isinstance(axis_name, (list, tuple)):
axis_name = (axis_name,)
x_aval = core.raise_to_shaped(x)
@ -1353,7 +1367,9 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
for name, size in x_aval.named_shape.items()
if name not in axis_name
}
return x_aval.update(shape=new_shape, named_shape=new_named_shape)
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,
@ -1401,7 +1417,9 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,
reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
reduce_scatter_p.def_effectful_abstract_eval(
_reduce_scatter_effectful_abstract_eval
)
ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
@ -1537,13 +1555,15 @@ def _axis_index_lowering(ctx, *, axis_name):
]
def _axis_index_abstract_eval(*, axis_name):
def _axis_index_effectful_abstract_eval(*, axis_name):
frame = core.axis_frame(axis_name)
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
out_aval = ShapedArray((), np.int32, named_shape={axis_name: frame.size})
return out_aval, {core.NamedAxisEffect(axis_name)}
axis_index_p = core.Primitive('axis_index')
mlir.register_lowering(axis_index_p, _axis_index_lowering)
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval)
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
# Axis index doesn't get any arguments, so that the default bind would have no
@ -1585,8 +1605,11 @@ def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch, precision):
if axis_name: raise NameError(f"unbound axis name: {axis_name[0]}")
return lax.dot_general(x, y, (pos_contract, pos_batch), precision=precision)
@pdot_p.def_abstract_eval
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
@pdot_p.def_effectful_abstract_eval
def _pdot_effectful_abstract_eval(
x, y, *, axis_name, pos_contract, pos_batch, precision
):
# TODO(frostig,mattjj,jekbradbury): check inputs have given axis names?
if not len(set(axis_name)) == len(axis_name): raise ValueError
pos_aval = lax.dot_general_p.abstract_eval(
@ -1596,7 +1619,10 @@ def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
named_shape = {name: size
for name, size in common_named_shape.items()
if name not in axis_name}
return pos_aval.update(named_shape=named_shape)
out_aval = pos_aval.update(named_shape=named_shape)
effects = {*map(core.NamedAxisEffect, axis_name)}
return out_aval, effects
def _pdot_vmap_collective_rule(axis_size, frame_name, _, vals_in, dims_in, *, axis_name,
pos_contract, pos_batch, precision):

View File

@ -937,7 +937,8 @@ def _typecheck_xmap(
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
for a, a_out_axes in zip(mapped_out_avals, out_axes)]
return out_avals, call_jaxpr.effects
effs = core.filter_named_axis_effects(call_jaxpr.effects, global_axis_sizes)
return out_avals, effs
core.custom_typechecks[xmap_p] = _typecheck_xmap
@ -1033,8 +1034,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
call_jaxpr=call_jaxpr)
del new_params['out_axes_thunk']
del new_params['spmd_out_axes_thunk']
effs = core.filter_named_axis_effects(call_jaxpr.effects, global_axis_sizes)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
new_params, call_jaxpr.effects, source_info)
new_params, effs, source_info)
self.frame.add_eqn(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore

View File

@ -426,7 +426,9 @@ def _shard_map_staging(
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
main = trace.main
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(
f, main, in_avals_
)
out_avals_ = map(_check_shapedarray, genavals)
_check_names(out_names_thunk(), out_avals_)
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
@ -445,8 +447,9 @@ def _shard_map_staging(
params = dict(mesh=mesh, in_names=in_names_staged,
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
check_rep=check_rep, rewrite=rewrite, auto=auto)
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
jaxpr.effects, source_info)
effs, source_info)
trace.frame.add_eqn(eqn)
return out_tracers
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
@ -495,7 +498,8 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
"sufficiently replicated")
out_avals_sharded = [x.aval for x in jaxpr.outvars]
out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
return out_avals, jaxpr.effects
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
return out_avals, effs
core.custom_typechecks[shard_map_p] = _shard_map_typecheck
def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]:
@ -1305,9 +1309,10 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in out_avals]
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), # type: ignore[arg-type]
out_tracers, shard_map_p, unk_params,
jaxpr.effects, source_info_util.current())
effs, source_info_util.current())
for t in out_tracers: t.recipe = eqn
return pe.merge_lists(out_knowns, out_tracers, out_consts)
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
@ -1345,8 +1350,9 @@ def _shard_map_partial_eval_post_process(
for a in out_avals]
name_stack = trace._current_truncated_name_stack()
source = source_info_util.current().replace(name_stack=name_stack)
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
shard_map_p, staged_params, jaxpr.effects, source)
shard_map_p, staged_params, effs, source)
for t in out_tracers: t.recipe = eqn
return merge_lists(out_knowns, out_tracers, out_consts)
@ -1456,6 +1462,8 @@ def _partial_eval_jaxpr_custom_rule(
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, ins_staged = partition_list(inst_in, eqn.invars)
@ -1535,7 +1543,8 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
mesh = eqn.params["mesh"]
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
return used_inputs, None
@ -1544,10 +1553,11 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
_, out_names = partition_list(used_outputs, eqn.params['out_names'])
new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names),
out_names=tuple(out_names))
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, used_inputs) if used],
[x for x, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, jaxpr.effects, eqn.source_info)
eqn.primitive, new_params, effs, eqn.source_info)
return used_inputs, new_eqn
pe.dce_rules[shard_map_p] = _shard_map_dce