From 5283d4b4a5298c3822039f35454adca4e80ea84b Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Mon, 4 Mar 2024 05:41:29 -0800
Subject: [PATCH] Axis names are now tracked via an effect

This allows propagating the names bottom up -- from equations to the
jaxpr, instead of "discovering" them top-down by traversing (and
rebuilding) the jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
---
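Reviewer note: the scheme in miniature. Each equation that uses a named
axis now carries a NamedAxisEffect; a jaxpr's effects are the union of
its equations' effects; and a binder (pmap/xmap/shard_map) discharges
the names it binds by filtering them out. A self-contained sketch with
mock classes -- not the real JAX internals:

    # Mock model of effect-based axis-name tracking.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class NamedAxisEffect:
        name: str

    @dataclass
    class Eqn:
        prim: str
        effects: frozenset

    @dataclass
    class MockJaxpr:
        eqns: list

        @property
        def effects(self):
            # Bottom-up: the jaxpr's effects are the union over equations.
            return frozenset(e for eqn in self.eqns for e in eqn.effects)

    def discharge(jaxpr, bound_names):
        # What a binder reports outward: everything except its own axes.
        return frozenset(e for e in jaxpr.effects
                         if not (isinstance(e, NamedAxisEffect)
                                 and e.name in bound_names))

    inner = MockJaxpr([Eqn('psum', frozenset({NamedAxisEffect('i')})),
                       Eqn('add', frozenset())])
    assert inner.effects == {NamedAxisEffect('i')}
    assert discharge(inner, {'i'}) == frozenset()  # pmap over 'i' discharges it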
 jax/_src/api.py                       |  4 ++
 jax/_src/core.py                      | 58 +++++++++++++++---
 jax/_src/interpreters/mlir.py         |  5 +-
 jax/_src/interpreters/partial_eval.py |  9 +--
 jax/_src/interpreters/pxla.py         |  9 ++-
 jax/_src/lax/parallel.py              | 88 +++++++++++++++++----------
 jax/_src/maps.py                      |  6 +-
 jax/experimental/shard_map.py         | 24 +++++---
 8 files changed, 145 insertions(+), 58 deletions(-)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index ad6fd2d88..005e0ceca 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -569,6 +569,10 @@ def xla_computation(fun: Callable,
       stack.enter_context(core.extend_axis_env(axis_name, size, None))
     jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
     jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
+    if axis_env:
+      jaxpr = core.remove_named_axis_effects(
+          jaxpr, {axis_name for axis_name, _ in axis_env}
+      )
     axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
     ordered_effects = list(
         effects.ordered_effects.filter_in(jaxpr.effects))
diff --git a/jax/_src/core.py b/jax/_src/core.py
index d63600888..454aea0f3 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -15,8 +15,9 @@ from __future__ import annotations

 import collections  # noqa: F401
 from collections import Counter, defaultdict, deque, namedtuple
-from collections.abc import (Generator, Hashable, Iterable, Iterator, Sequence,
-                             MutableSet, MutableMapping)
+from collections.abc import (Collection, Generator, Hashable, Iterable,
+                             Iterator, Set, Sequence, MutableSet,
+                             MutableMapping)
 from contextlib import contextmanager
 from dataclasses import dataclass
 import functools
@@ -2637,6 +2638,34 @@ def axis_frame(axis_name: AxisName, main_trace: MainTrace | None = None
           f'by pmap) are available to collective operations: {named_axes}')


+@dataclass(frozen=True)
+class NamedAxisEffect(effects.Effect):
+  """A side-effect introducing a new named axis into the current scope."""
+
+  name: AxisName
+
+
+effects.control_flow_allowed_effects.add_type(NamedAxisEffect)
+effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect)
+effects.lowerable_effects.add_type(NamedAxisEffect)
+effects.remat_allowed_effects.add_type(NamedAxisEffect)
+
+
+def filter_named_axis_effects(
+    effects: Effects, names: Collection[AxisName]
+) -> Effects:
+  return {e for e in effects
+          if not isinstance(e, NamedAxisEffect) or e.name not in names}
+
+
+def remove_named_axis_effects(
+    jaxpr: Jaxpr, names: Collection[AxisName]
+) -> Jaxpr:
+  if not names or not jaxpr.effects:
+    return jaxpr
+  return jaxpr.replace(effects=filter_named_axis_effects(jaxpr.effects, names))
+
+
 ParamDict = dict[str, Any]
 AxisSubst = Callable[[AxisName], tuple[AxisName, ...]]

@@ -2676,6 +2705,15 @@ class DuplicateAxisNameError(Exception):
     self.var = var
     self.eqn = None

+def subst_axis_names_effects(effects: Set[Effect], subst: AxisSubst) -> Set[Effect]:
+  new_effects = set[Effect]()
+  for e in effects:
+    if isinstance(e, NamedAxisEffect):
+      new_effects.update(map(NamedAxisEffect, subst(e.name)))
+    else:
+      new_effects.add(e)
+  return new_effects
+
 def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var:
   # Var identity is load-bearing, so we can't have duplicates!
   if isinstance(v, DropVar): return v
@@ -2699,7 +2737,8 @@ def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var
       e.eqn = eqn
       raise
   params = subst_axis_names(eqn.primitive, eqn.params, subst)
-  return eqn.replace(invars=invars, outvars=outvars, params=params)
+  effects = subst_axis_names_effects(eqn.effects, subst)
+  return eqn.replace(invars=invars, outvars=outvars, params=params, effects=effects)

 def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
   consts = None
@@ -2711,16 +2750,14 @@ def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
   constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars]  # type: ignore[union-attr]
   eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns]  # type: ignore[union-attr]
   outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars]  # type: ignore[union-attr]
-  new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
+  effects = subst_axis_names_effects(jaxpr.effects, subst)
+  new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, effects)
   if consts is not None:
     return ClosedJaxpr(new_jaxpr, consts)
   return new_jaxpr

-@weakref_lru_cache
 def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr):
-  subst = NameGatheringSubst()
-  do_subst_axis_names_jaxpr(jaxpr, subst)
-  return frozenset(subst.axis_names)
+  return {e.name for e in jaxpr.effects if isinstance(e, NamedAxisEffect)}

 def subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
   if isinstance(subst, NameGatheringSubst):  # This is a common case, so we optimize it!
@@ -2924,6 +2961,9 @@ def _check_jaxpr(
         raise JaxprTypeError(
             "Invalid `JaxprInputEffect`: must be present in jaxpr. "
             f"{jaxpr_effect} is not in {jaxpr.effects}.")
+      elif isinstance(eff, NamedAxisEffect):
+        # It is valid for a primitive to discharge the named axis effect.
+        continue
       elif eff not in jaxpr.effects:
         raise JaxprTypeError("Equation effect not present in jaxpr effects. "
                              f"Equation effect: {eff}. "
@@ -3077,7 +3117,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
   out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
                if out_axis is not None else aval
                for aval, out_axis in zip(mapped_out_avals, out_axes)]
-  return out_avals, call_jaxpr.effects
+  return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})

 # ------------------- Jaxpr printed representation -------------------

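A quick illustration of the two new core helpers (private API, assuming
this patch is applied; the axis names 'i' and 'j' are invented for the
demo). filter_named_axis_effects works on any effect set, and
remove_named_axis_effects is the same filter applied to a jaxpr, with an
early return when there is nothing to do:

    from jax._src.core import NamedAxisEffect, filter_named_axis_effects

    effects = {NamedAxisEffect('i'), NamedAxisEffect('j')}
    assert filter_named_axis_effects(effects, {'i'}) == {NamedAxisEffect('j')}
    assert filter_named_axis_effects(effects, set()) == effects  # no-op

Note also that used_axis_names_jaxpr drops its @weakref_lru_cache:
reading names off jaxpr.effects is already cheap, whereas the old
implementation re-ran the whole substitution traversal to gather them.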
" @@ -3077,7 +3117,7 @@ def _check_map(ctx_factory, prim, in_avals, params): out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval) if out_axis is not None else aval for aval, out_axis in zip(mapped_out_avals, out_axes)] - return out_avals, call_jaxpr.effects + return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name}) # ------------------- Jaxpr printed representation ------------------- diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index a65bad6ee..95734493c 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -2570,8 +2570,9 @@ def build_mlir_module_helper( platforms: Sequence[str], backend_or_name: str, axis_context: AxisContext) -> ir.Module: """Helper to generate pmap-style XLA computations for custom partitioners.""" - if closed_jaxpr.effects: - raise NotImplementedError + unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects) + if unlowerable_effects: + raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}') lowering_result = lower_jaxpr_to_module(name, closed_jaxpr, backend_or_name=backend_or_name, ordered_effects=[], name_stack=source_info_util.NameStack(), diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ebdc01300..ef2a65434 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -380,10 +380,10 @@ class JaxprTrace(Trace['JaxprTracer']): for ax, a in zip(staged_out_axes, out_avals_mapped)] out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None) for a in out_avals] + effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']}) + src_info = source_info_util.current() eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type] - out_tracers, primitive, staged_params, - jaxpr.effects, - source_info_util.current()) + out_tracers, primitive, staged_params, effs, src_info) for t in out_tracers: t.recipe = eqn return merge_lists(out_knowns, out_tracers, out_consts) @@ -2099,8 +2099,9 @@ class DynamicJaxprTrace(core.Trace): update_params = call_param_updaters.get(map_primitive) if update_params: new_params = update_params(new_params, [True] * len(tracers), len(consts)) + effs = core.filter_named_axis_effects(jaxpr.effects, {axis_name}) eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, - new_params, jaxpr.effects, source_info) + new_params, effs, source_info) self.frame.add_eqn(eqn) return out_tracers diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index e8ddc67b2..95c4c499e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -766,6 +766,7 @@ def lower_parallel_callable( axis_env = sharding_impls.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap')) + jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name}) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) replicated_args = [axis is None for axis in in_axes] tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals), @@ -1258,8 +1259,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval): def _pmap_dce_rule(used_outputs, eqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes - with maybe_extend_axis_env(eqn.params['axis_name'], - eqn.params['global_axis_size'], None): + axis_name = eqn.params["axis_name"] + with maybe_extend_axis_env(axis_name, 
eqn.params["global_axis_size"], None): new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars']) _, in_axes = partition_list(used_inputs, eqn.params['in_axes']) @@ -1270,10 +1271,11 @@ def _pmap_dce_rule(used_outputs, eqn): if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects: return used_inputs, None else: + effs = core.filter_named_axis_effects(new_jaxpr.effects, {axis_name}) new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [v for v, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info) + eqn.primitive, new_params, effs, eqn.source_info) return used_inputs, new_eqn @@ -2238,6 +2240,7 @@ def lower_mesh_computation( axis_ctx = sharding_impls.ReplicaAxisContext(axis_env) num_replicas = mesh.devices.size num_partitions = 1 + jaxpr = core.remove_named_axis_effects(jaxpr, mesh.axis_names) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) module_name = f"{api_name}_{fun_name}" with core.extend_axis_env_nd(mesh.shape.items()): diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index eaaf08c16..7226b2649 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -23,15 +23,12 @@ import itertools import math import string -import numpy as np - from jax import tree_util - from jax._src import core from jax._src import dtypes from jax._src import sharding_impls from jax._src import util -from jax._src.core import ShapedArray, AxisName, raise_to_shaped +from jax._src.core import AxisName, ShapedArray, raise_to_shaped from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -41,8 +38,9 @@ from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy import lax_numpy -from jax._src.util import ( - unzip2, canonicalize_axis, safe_map, safe_zip, moveaxis) +from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, + unzip2) +import numpy as np unsafe_map, map = map, safe_map # type: ignore @@ -709,21 +707,23 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups): assert all(isinstance(axis, int) for axis in axes) return [pos_reducer(arg, axes) for arg in args] -def _allreduce_abstract_eval(*args, axes, axis_index_groups): +def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): # TODO(frostig,mattjj,jekbradbury): maybe check aval names here pos_axes = tuple(axis for axis in axes if isinstance(axis, int)) named_shapes = [arg.named_shape for arg in args] + named_axes = {axis for axis in axes if not isinstance(axis, int)} if axis_index_groups is None: - named_axes = {axis for axis in axes if not isinstance(axis, int)} named_shapes = [{name: size for name, size in arg.named_shape.items() if name not in named_axes} for arg in args] else: if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") - return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes), - arg.dtype, named_shape=named_shape) - for arg, named_shape in zip(args, named_shapes)] + out_avals = [ + ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes), + arg.dtype, named_shape=named_shape) + for arg, named_shape in zip(args, named_shapes)] + return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def 
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index eaaf08c16..7226b2649 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -23,15 +23,12 @@ import itertools
 import math
 import string

-import numpy as np
-
 from jax import tree_util
-
 from jax._src import core
 from jax._src import dtypes
 from jax._src import sharding_impls
 from jax._src import util
-from jax._src.core import ShapedArray, AxisName, raise_to_shaped
+from jax._src.core import AxisName, ShapedArray, raise_to_shaped
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -41,8 +38,9 @@ from jax._src.lax import slicing
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy import lax_numpy
-from jax._src.util import (
-    unzip2, canonicalize_axis, safe_map, safe_zip, moveaxis)
+from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip,
+                           unzip2)
+import numpy as np

 unsafe_map, map = map, safe_map  # type: ignore

@@ -709,21 +707,23 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
   assert all(isinstance(axis, int) for axis in axes)
   return [pos_reducer(arg, axes) for arg in args]

-def _allreduce_abstract_eval(*args, axes, axis_index_groups):
+def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
   # TODO(frostig,mattjj,jekbradbury): maybe check aval names here
   pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
   named_shapes = [arg.named_shape for arg in args]
+  named_axes = {axis for axis in axes if not isinstance(axis, int)}
   if axis_index_groups is None:
-    named_axes = {axis for axis in axes if not isinstance(axis, int)}
     named_shapes = [{name: size for name, size in arg.named_shape.items()
                      if name not in named_axes} for arg in args]
   else:
     if len(pos_axes) != 0:
       raise ValueError(f"axis_index_groups can only be used with reductions over "
                        f"named axes, but got: {axes}")
-  return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
-                      arg.dtype, named_shape=named_shape)
-          for arg, named_shape in zip(args, named_shapes)]
+  out_avals = [
+      ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
+                  arg.dtype, named_shape=named_shape)
+      for arg, named_shape in zip(args, named_shapes)]
+  return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}

 def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
   if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
@@ -804,7 +804,7 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
 psum_p = core.AxisPrimitive('psum')
 psum_p.multiple_results = True
 psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
-psum_p.def_abstract_eval(_allreduce_abstract_eval)
+psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
 mlir.register_lowering(
     psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
 ad.deflinear2(psum_p, _psum_transpose_rule)
@@ -840,7 +840,7 @@ def psum_bind(*args, axes, axis_index_groups):
 pmax_p = core.AxisPrimitive('pmax')
 pmax_p.multiple_results = True
 pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
-pmax_p.def_abstract_eval(_allreduce_abstract_eval)
+pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
 mlir.register_lowering(
     pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
 batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
@@ -852,7 +852,7 @@ core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes'
 pmin_p = core.AxisPrimitive('pmin')
 pmin_p.multiple_results = True
 pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
-pmin_p.def_abstract_eval(_allreduce_abstract_eval)
+pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
 mlir.register_lowering(
     pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
 batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
@@ -1060,17 +1060,25 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
       new_d -= 1  # We've removed 0th dimension, so new_d needs to be adjusted
   return x, new_d

-def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups):
+
+def _all_to_all_effectful_abstract_eval(
+    x, axis_name, split_axis, concat_axis, axis_index_groups
+):
+  if not isinstance(axis_name, (list, tuple)):
+    axis_name = (axis_name,)
   input_aval = raise_to_shaped(x)
   shape = list(input_aval.shape)
   axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0])
   assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
   shape[split_axis] //= axis_size
   shape[concat_axis] *= axis_size
-  return input_aval.update(shape=tuple(shape), weak_type=False)
+  out_aval = input_aval.update(shape=tuple(shape), weak_type=False)
+  effects = {*map(core.NamedAxisEffect, axis_name)}
+  return out_aval, effects
+

 all_to_all_p = core.AxisPrimitive('all_to_all')
-all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
+all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval)
 mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
 ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
 batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
@@ -1204,7 +1212,10 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
         replica_groups=_replica_groups_hlo(replica_groups),
         **other_args).results

-def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
+
+def _all_gather_effectful_abstract_eval(
+    x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
+):
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   x_aval = raise_to_shaped(x)
@@ -1215,7 +1226,10 @@ def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_
     new_shape.insert(all_gather_dimension, axis_size)
   new_named_shape = {name: size for name, size in x_aval.named_shape.items()
                      if name not in axis_name}
-  return x_aval.update(shape=new_shape, named_shape=new_named_shape)
+  out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
+  effects = {*map(core.NamedAxisEffect, axis_name)}
+  return out_aval, effects
+
 def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name,
                                axis_index_groups, axis_size, tiled):
   return (psum_scatter(cts, axis_name=axis_name,
@@ -1264,7 +1278,7 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
   return y, batching.not_mapped

 all_gather_p = core.AxisPrimitive('all_gather')
-all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
+all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval)
 all_gather_p.def_impl(_all_gather_impl)
 mlir.register_lowering(all_gather_p, _all_gather_lowering)
 for p in ("cuda", "rocm", "tpu"):
@@ -1327,9 +1341,9 @@ def _reduce_scatter_lowering(
     return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]


-
-def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
-                                  axis_index_groups, axis_size, tiled):
+def _reduce_scatter_effectful_abstract_eval(
+    x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
+):
   if not isinstance(axis_name, (list, tuple)):
     axis_name = (axis_name,)
   x_aval = core.raise_to_shaped(x)
@@ -1353,7 +1367,9 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
       for name, size in x_aval.named_shape.items()
       if name not in axis_name
   }
-  return x_aval.update(shape=new_shape, named_shape=new_named_shape)
+  out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
+  effects = {*map(core.NamedAxisEffect, axis_name)}
+  return out_aval, effects


 def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,
@@ -1401,7 +1417,9 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,


 reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
-reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
+reduce_scatter_p.def_effectful_abstract_eval(
+    _reduce_scatter_effectful_abstract_eval
+)
 ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
 batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
 batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
@@ -1537,13 +1555,15 @@ def _axis_index_lowering(ctx, *, axis_name):
   ]


-def _axis_index_abstract_eval(*, axis_name):
+def _axis_index_effectful_abstract_eval(*, axis_name):
   frame = core.axis_frame(axis_name)
-  return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
+  out_aval = ShapedArray((), np.int32, named_shape={axis_name: frame.size})
+  return out_aval, {core.NamedAxisEffect(axis_name)}
+

 axis_index_p = core.Primitive('axis_index')
 mlir.register_lowering(axis_index_p, _axis_index_lowering)
-axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
+axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval)
 core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')

 # Axis index doesn't get any arguments, so that the default bind would have no
@@ -1585,8 +1605,11 @@ def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch, precision):
   if axis_name: raise NameError(f"unbound axis name: {axis_name[0]}")
   return lax.dot_general(x, y, (pos_contract, pos_batch), precision=precision)

-@pdot_p.def_abstract_eval
-def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
+
+@pdot_p.def_effectful_abstract_eval
+def _pdot_effectful_abstract_eval(
+    x, y, *, axis_name, pos_contract, pos_batch, precision
+):
   # TODO(frostig,mattjj,jekbradbury): check inputs have given axis names?
   if not len(set(axis_name)) == len(axis_name): raise ValueError
   pos_aval = lax.dot_general_p.abstract_eval(
@@ -1596,7 +1619,10 @@ def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
   named_shape = {name: size
                  for name, size in common_named_shape.items()
                  if name not in axis_name}
-  return pos_aval.update(named_shape=named_shape)
+  out_aval = pos_aval.update(named_shape=named_shape)
+  effects = {*map(core.NamedAxisEffect, axis_name)}
+  return out_aval, effects
+

 def _pdot_vmap_collective_rule(axis_size, frame_name, _, vals_in, dims_in, *,
                                axis_name, pos_contract, pos_batch, precision):
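The same pattern repeats across psum/pmax/pmin, all_to_all, all_gather,
reduce_scatter, axis_index, and pdot above: normalize axis_name to a
tuple, then emit one NamedAxisEffect per named axis alongside the output
aval. Condensed into a standalone helper (a sketch of the shared shape,
not a function this patch adds):

    from jax._src import core

    def _collective_named_axis_effects(axis_name):
      if not isinstance(axis_name, (list, tuple)):
        axis_name = (axis_name,)  # single names are normalized to a 1-tuple
      return {core.NamedAxisEffect(name) for name in axis_name}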
diff --git a/jax/_src/maps.py b/jax/_src/maps.py
index b6433203b..bde12f3fc 100644
--- a/jax/_src/maps.py
+++ b/jax/_src/maps.py
@@ -937,7 +937,8 @@ def _typecheck_xmap(
   mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
   out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
                for a, a_out_axes in zip(mapped_out_avals, out_axes)]
-  return out_avals, call_jaxpr.effects
+  effs = core.filter_named_axis_effects(call_jaxpr.effects, global_axis_sizes)
+  return out_avals, effs
 core.custom_typechecks[xmap_p] = _typecheck_xmap

@@ -1033,8 +1034,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
                     call_jaxpr=call_jaxpr)
   del new_params['out_axes_thunk']
   del new_params['spmd_out_axes_thunk']
+  effs = core.filter_named_axis_effects(call_jaxpr.effects, global_axis_sizes)
   eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
                       new_params, effs, source_info)
-                      new_params, call_jaxpr.effects, source_info)
   self.frame.add_eqn(eqn)
   return out_tracers
 pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap  # type: ignore
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index 9f6f9f5b4..71250a34c 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -426,7 +426,9 @@ def _shard_map_staging(
   in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
   main = trace.main
   with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
+    jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(
+        f, main, in_avals_
+    )
   out_avals_ = map(_check_shapedarray, genavals)
   _check_names(out_names_thunk(), out_avals_)
   in_rep = map(partial(_in_names_to_rep, mesh), in_names)
@@ -445,8 +447,9 @@ def _shard_map_staging(
   params = dict(mesh=mesh, in_names=in_names_staged,
                 out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
                 check_rep=check_rep, rewrite=rewrite, auto=auto)
+  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
   eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
-                         jaxpr.effects, source_info)
+                         effs, source_info)
   trace.frame.add_eqn(eqn)
   return out_tracers
 pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
@@ -495,7 +498,8 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
                     "sufficiently replicated")
   out_avals_sharded = [x.aval for x in jaxpr.outvars]
   out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
-  return out_avals, jaxpr.effects
+  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
+  return out_avals, effs
 core.custom_typechecks[shard_map_p] = _shard_map_typecheck

 def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]:
@@ -1305,9 +1309,10 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
   out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
   out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
                  for a in out_avals]
+  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
   eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers),  # type: ignore[arg-type]
                           out_tracers, shard_map_p, unk_params,
-                          jaxpr.effects, source_info_util.current())
+                          effs, source_info_util.current())
   for t in out_tracers: t.recipe = eqn
   return pe.merge_lists(out_knowns, out_tracers, out_consts)
 pe.JaxprTrace.process_shard_map = _shard_map_partial_eval

@@ -1345,8 +1350,9 @@ def _shard_map_partial_eval_post_process(
                  for a in out_avals]
   name_stack = trace._current_truncated_name_stack()
   source = source_info_util.current().replace(name_stack=name_stack)
+  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
   eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
-                          shard_map_p, staged_params, jaxpr.effects, source)
+                          shard_map_p, staged_params, effs, source)
   for t in out_tracers: t.recipe = eqn
   return merge_lists(out_knowns, out_tracers, out_consts)

@@ -1456,6 +1462,8 @@ def _partial_eval_jaxpr_custom_rule(
   with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
     jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
     jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
+  jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
+  jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names)
   ins_known, _ = partition_list(unks_in, eqn.invars)
   out_binders_known, _ = partition_list(unks_out, eqn.outvars)
   _, ins_staged = partition_list(inst_in, eqn.invars)
@@ -1535,7 +1543,8 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,

 # TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
 def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
                    ) -> tuple[list[bool], core.JaxprEqn | None]:
-  with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
+  mesh = eqn.params["mesh"]
+  with core.extend_axis_env_nd(mesh.shape.items()):
     jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
   if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
     return used_inputs, None
@@ -1544,10 +1553,11 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
   _, out_names = partition_list(used_outputs, eqn.params['out_names'])
   new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names),
                     out_names=tuple(out_names))
+  effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
   new_eqn = pe.new_jaxpr_eqn(
       [v for v, used in zip(eqn.invars, used_inputs) if used],
       [x for x, used in zip(eqn.outvars, used_outputs) if used],
-      eqn.primitive, new_params, jaxpr.effects, eqn.source_info)
+      eqn.primitive, new_params, effs, eqn.source_info)
   return used_inputs, new_eqn
 pe.dce_rules[shard_map_p] = _shard_map_dce
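Closing note: the one invariant this change relaxes is in _check_jaxpr
above -- an equation may now carry a NamedAxisEffect that its enclosing
jaxpr no longer reports, because a binder in between discharged it. In
miniature (standalone sketch with a stand-in effect type, not the real
checker):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class NamedAxisEffect:
      name: str

    def check_eqn_effects(eqn_effects, jaxpr_effects):
      for eff in eqn_effects:
        if isinstance(eff, NamedAxisEffect):
          continue  # may legitimately have been discharged by a binder
        if eff not in jaxpr_effects:
          raise TypeError(f'equation effect {eff} missing from jaxpr effects')

    # A discharged named-axis effect passes the check:
    check_eqn_effects({NamedAxisEffect('i')}, set())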