mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Axis names are now tracked via an effect
This allows propagating the names bottom-up -- from equations to the jaxpr -- instead of "discovering" them top-down by traversing (and rebuilding) the jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
This commit is contained in:
parent
2dd5e9e180
commit
5283d4b4a5
@ -569,6 +569,10 @@ def xla_computation(fun: Callable,
|
||||
stack.enter_context(core.extend_axis_env(axis_name, size, None))
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals)
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
if axis_env:
|
||||
jaxpr = core.remove_named_axis_effects(
|
||||
jaxpr, {axis_name for axis_name, _ in axis_env}
|
||||
)
|
||||
axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr))
|
||||
ordered_effects = list(
|
||||
effects.ordered_effects.filter_in(jaxpr.effects))
|
||||
|
@ -15,8 +15,9 @@ from __future__ import annotations
|
||||
|
||||
import collections # noqa: F401
|
||||
from collections import Counter, defaultdict, deque, namedtuple
|
||||
from collections.abc import (Generator, Hashable, Iterable, Iterator, Sequence,
|
||||
MutableSet, MutableMapping)
|
||||
from collections.abc import (Collection, Generator, Hashable, Iterable,
|
||||
Iterator, Set, Sequence, MutableSet,
|
||||
MutableMapping)
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
@ -2637,6 +2638,34 @@ def axis_frame(axis_name: AxisName, main_trace: MainTrace | None = None
|
||||
f'by pmap) are available to collective operations: {named_axes}')
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NamedAxisEffect(effects.Effect):
|
||||
"""A side-effect introducing a new named axis into the current scope."""
|
||||
|
||||
name: AxisName
|
||||
|
||||
|
||||
effects.control_flow_allowed_effects.add_type(NamedAxisEffect)
|
||||
effects.custom_derivatives_allowed_effects.add_type(NamedAxisEffect)
|
||||
effects.lowerable_effects.add_type(NamedAxisEffect)
|
||||
effects.remat_allowed_effects.add_type(NamedAxisEffect)
|
||||
|
||||
|
||||
def filter_named_axis_effects(
|
||||
effects: Effects, names: Collection[AxisName]
|
||||
) -> Effects:
|
||||
return {e for e in effects
|
||||
if not isinstance(e, NamedAxisEffect) or e.name not in names}
|
||||
|
||||
|
||||
def remove_named_axis_effects(
|
||||
jaxpr: Jaxpr, names: Collection[AxisName]
|
||||
) -> Jaxpr:
|
||||
if not names or not jaxpr.effects:
|
||||
return jaxpr
|
||||
return jaxpr.replace(effects=filter_named_axis_effects(jaxpr.effects, names))
|
||||
|
||||
|
||||
ParamDict = dict[str, Any]
|
||||
AxisSubst = Callable[[AxisName], tuple[AxisName, ...]]
|
||||
|
||||
@ -2676,6 +2705,15 @@ class DuplicateAxisNameError(Exception):
|
||||
self.var = var
|
||||
self.eqn = None
|
||||
|
||||
def subst_axis_names_effects(effects: Set[Effect], subst: AxisSubst) -> Set[Effect]:
|
||||
new_effects = set[Effect]()
|
||||
for e in effects:
|
||||
if isinstance(e, NamedAxisEffect):
|
||||
new_effects.update(map(NamedAxisEffect, subst(e.name)))
|
||||
else:
|
||||
new_effects.add(e)
|
||||
return new_effects
|
||||
|
||||
def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: dict[Var, Var]) -> Var:
|
||||
# Var identity is load-bearing, so we can't have duplicates!
|
||||
if isinstance(v, DropVar): return v
|
||||
@ -2699,7 +2737,8 @@ def subst_axis_names_eqn(eqn: JaxprEqn, subst: AxisSubst, var_map: dict[Var, Var
|
||||
e.eqn = eqn
|
||||
raise
|
||||
params = subst_axis_names(eqn.primitive, eqn.params, subst)
|
||||
return eqn.replace(invars=invars, outvars=outvars, params=params)
|
||||
effects = subst_axis_names_effects(eqn.effects, subst)
|
||||
return eqn.replace(invars=invars, outvars=outvars, params=params, effects=effects)
|
||||
|
||||
def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
|
||||
consts = None
|
||||
@ -2711,16 +2750,14 @@ def do_subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
|
||||
constvars = [subst_axis_names_var(v, subst, var_map) for v in jaxpr.constvars] # type: ignore[union-attr]
|
||||
eqns = [subst_axis_names_eqn(eqn, subst, var_map) for eqn in jaxpr.eqns] # type: ignore[union-attr]
|
||||
outvars: list[Atom] = [v if isinstance(v, Literal) else var_map[v] for v in jaxpr.outvars] # type: ignore[union-attr]
|
||||
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, jaxpr.effects)
|
||||
effects = subst_axis_names_effects(jaxpr.effects, subst)
|
||||
new_jaxpr = Jaxpr(constvars, invars, outvars, eqns, effects)
|
||||
if consts is not None:
|
||||
return ClosedJaxpr(new_jaxpr, consts)
|
||||
return new_jaxpr
|
||||
|
||||
@weakref_lru_cache
|
||||
def used_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr):
|
||||
subst = NameGatheringSubst()
|
||||
do_subst_axis_names_jaxpr(jaxpr, subst)
|
||||
return frozenset(subst.axis_names)
|
||||
return {e.name for e in jaxpr.effects if isinstance(e, NamedAxisEffect)}
|
||||
|
||||
def subst_axis_names_jaxpr(jaxpr: Jaxpr | ClosedJaxpr, subst: AxisSubst):
|
||||
if isinstance(subst, NameGatheringSubst): # This is a common case, so we optimize it!
|
||||
@ -2924,6 +2961,9 @@ def _check_jaxpr(
|
||||
raise JaxprTypeError(
|
||||
"Invalid `JaxprInputEffect`: must be present in jaxpr. "
|
||||
f"{jaxpr_effect} is not in {jaxpr.effects}.")
|
||||
elif isinstance(eff, NamedAxisEffect):
|
||||
# It is valid for a primitive to discharge the named axis effect.
|
||||
continue
|
||||
elif eff not in jaxpr.effects:
|
||||
raise JaxprTypeError("Equation effect not present in jaxpr effects. "
|
||||
f"Equation effect: {eff}. "
|
||||
@ -3077,7 +3117,7 @@ def _check_map(ctx_factory, prim, in_avals, params):
|
||||
out_avals = [unmapped_aval(axis_size, axis_name, out_axis, aval)
|
||||
if out_axis is not None else aval
|
||||
for aval, out_axis in zip(mapped_out_avals, out_axes)]
|
||||
return out_avals, call_jaxpr.effects
|
||||
return out_avals, filter_named_axis_effects(call_jaxpr.effects, {axis_name})
|
||||
|
||||
|
||||
# ------------------- Jaxpr printed representation -------------------
|
||||
|
@ -2570,8 +2570,9 @@ def build_mlir_module_helper(
|
||||
platforms: Sequence[str],
|
||||
backend_or_name: str, axis_context: AxisContext) -> ir.Module:
|
||||
"""Helper to generate pmap-style XLA computations for custom partitioners."""
|
||||
if closed_jaxpr.effects:
|
||||
raise NotImplementedError
|
||||
unlowerable_effects = lowerable_effects.filter_not_in(closed_jaxpr.effects)
|
||||
if unlowerable_effects:
|
||||
raise ValueError(f'Cannot lower jaxpr with effects: {closed_jaxpr.effects}')
|
||||
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
|
||||
backend_or_name=backend_or_name, ordered_effects=[],
|
||||
name_stack=source_info_util.NameStack(),
|
||||
|
@ -380,10 +380,10 @@ class JaxprTrace(Trace['JaxprTracer']):
|
||||
for ax, a in zip(staged_out_axes, out_avals_mapped)]
|
||||
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
|
||||
for a in out_avals]
|
||||
effs = core.filter_named_axis_effects(jaxpr.effects, {params['axis_name']})
|
||||
src_info = source_info_util.current()
|
||||
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type]
|
||||
out_tracers, primitive, staged_params,
|
||||
jaxpr.effects,
|
||||
source_info_util.current())
|
||||
out_tracers, primitive, staged_params, effs, src_info)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
@ -2099,8 +2099,9 @@ class DynamicJaxprTrace(core.Trace):
|
||||
update_params = call_param_updaters.get(map_primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, [True] * len(tracers), len(consts))
|
||||
effs = core.filter_named_axis_effects(jaxpr.effects, {axis_name})
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
|
||||
new_params, jaxpr.effects, source_info)
|
||||
new_params, effs, source_info)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
|
||||
|
@ -766,6 +766,7 @@ def lower_parallel_callable(
|
||||
axis_env = sharding_impls.AxisEnv(
|
||||
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
|
||||
jaxpr = core.remove_named_axis_effects(jaxpr, {axis_name})
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
replicated_args = [axis is None for axis in in_axes]
|
||||
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
|
||||
@ -1258,8 +1259,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):
|
||||
|
||||
def _pmap_dce_rule(used_outputs, eqn):
|
||||
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
|
||||
with maybe_extend_axis_env(eqn.params['axis_name'],
|
||||
eqn.params['global_axis_size'], None):
|
||||
axis_name = eqn.params["axis_name"]
|
||||
with maybe_extend_axis_env(axis_name, eqn.params["global_axis_size"], None):
|
||||
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
|
||||
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
|
||||
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
|
||||
@ -1270,10 +1271,11 @@ def _pmap_dce_rule(used_outputs, eqn):
|
||||
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
|
||||
return used_inputs, None
|
||||
else:
|
||||
effs = core.filter_named_axis_effects(new_jaxpr.effects, {axis_name})
|
||||
new_eqn = pe.new_jaxpr_eqn(
|
||||
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[v for v, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info)
|
||||
eqn.primitive, new_params, effs, eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
|
||||
|
||||
@ -2238,6 +2240,7 @@ def lower_mesh_computation(
|
||||
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
|
||||
num_replicas = mesh.devices.size
|
||||
num_partitions = 1
|
||||
jaxpr = core.remove_named_axis_effects(jaxpr, mesh.axis_names)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
|
@ -23,15 +23,12 @@ import itertools
|
||||
import math
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import tree_util
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import util
|
||||
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
|
||||
from jax._src.core import AxisName, ShapedArray, raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -41,8 +38,9 @@ from jax._src.lax import slicing
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.util import (
|
||||
unzip2, canonicalize_axis, safe_map, safe_zip, moveaxis)
|
||||
from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip,
|
||||
unzip2)
|
||||
import numpy as np
|
||||
|
||||
unsafe_map, map = map, safe_map # type: ignore
|
||||
|
||||
@ -709,21 +707,23 @@ def _allreduce_impl(pos_reducer, *args, axes, axis_index_groups):
|
||||
assert all(isinstance(axis, int) for axis in axes)
|
||||
return [pos_reducer(arg, axes) for arg in args]
|
||||
|
||||
def _allreduce_abstract_eval(*args, axes, axis_index_groups):
|
||||
def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
|
||||
# TODO(frostig,mattjj,jekbradbury): maybe check aval names here
|
||||
pos_axes = tuple(axis for axis in axes if isinstance(axis, int))
|
||||
named_shapes = [arg.named_shape for arg in args]
|
||||
if axis_index_groups is None:
|
||||
named_axes = {axis for axis in axes if not isinstance(axis, int)}
|
||||
if axis_index_groups is None:
|
||||
named_shapes = [{name: size for name, size in arg.named_shape.items()
|
||||
if name not in named_axes} for arg in args]
|
||||
else:
|
||||
if len(pos_axes) != 0:
|
||||
raise ValueError(f"axis_index_groups can only be used with reductions over "
|
||||
f"named axes, but got: {axes}")
|
||||
return [ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
|
||||
out_avals = [
|
||||
ShapedArray(lax._reduce_op_shape_rule(raise_to_shaped(arg), axes=pos_axes),
|
||||
arg.dtype, named_shape=named_shape)
|
||||
for arg, named_shape in zip(args, named_shapes)]
|
||||
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
|
||||
|
||||
def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
if axis_index_groups is not None and ("tpu" in ctx.module_context.platforms):
|
||||
@ -804,7 +804,7 @@ def _psum_transpose_rule(cts, *args, axes, axis_index_groups):
|
||||
psum_p = core.AxisPrimitive('psum')
|
||||
psum_p.multiple_results = True
|
||||
psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
|
||||
psum_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
psum_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
|
||||
mlir.register_lowering(
|
||||
psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
|
||||
ad.deflinear2(psum_p, _psum_transpose_rule)
|
||||
@ -840,7 +840,7 @@ def psum_bind(*args, axes, axis_index_groups):
|
||||
pmax_p = core.AxisPrimitive('pmax')
|
||||
pmax_p.multiple_results = True
|
||||
pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
|
||||
pmax_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
pmax_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
|
||||
mlir.register_lowering(
|
||||
pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
|
||||
batching.primitive_batchers[pmax_p] = partial(_reduction_batcher, pmax_p)
|
||||
@ -852,7 +852,7 @@ core.axis_substitution_rules[pmax_p] = partial(_subst_all_names_in_param, 'axes'
|
||||
pmin_p = core.AxisPrimitive('pmin')
|
||||
pmin_p.multiple_results = True
|
||||
pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
|
||||
pmin_p.def_abstract_eval(_allreduce_abstract_eval)
|
||||
pmin_p.def_effectful_abstract_eval(_allreduce_effectful_abstract_eval)
|
||||
mlir.register_lowering(
|
||||
pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
|
||||
batching.primitive_batchers[pmin_p] = partial(_reduction_batcher, pmin_p)
|
||||
@ -1060,17 +1060,25 @@ def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in,
|
||||
new_d -= 1 # We've removed 0th dimension, so new_d needs to be adjusted
|
||||
return x, new_d
|
||||
|
||||
def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_groups):
|
||||
|
||||
def _all_to_all_effectful_abstract_eval(
|
||||
x, axis_name, split_axis, concat_axis, axis_index_groups
|
||||
):
|
||||
if not isinstance(axis_name, (list, tuple)):
|
||||
axis_name = (axis_name,)
|
||||
input_aval = raise_to_shaped(x)
|
||||
shape = list(input_aval.shape)
|
||||
axis_size = psum(1, axis_name) if axis_index_groups is None else len(axis_index_groups[0])
|
||||
assert shape[split_axis] % axis_size == 0, (shape[split_axis], axis_size)
|
||||
shape[split_axis] //= axis_size
|
||||
shape[concat_axis] *= axis_size
|
||||
return input_aval.update(shape=tuple(shape), weak_type=False)
|
||||
out_aval = input_aval.update(shape=tuple(shape), weak_type=False)
|
||||
effects = {*map(core.NamedAxisEffect, axis_name)}
|
||||
return out_aval, effects
|
||||
|
||||
|
||||
all_to_all_p = core.AxisPrimitive('all_to_all')
|
||||
all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
|
||||
all_to_all_p.def_effectful_abstract_eval(_all_to_all_effectful_abstract_eval)
|
||||
mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
|
||||
ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
|
||||
batching.primitive_batchers[all_to_all_p] = _all_to_all_batcher
|
||||
@ -1204,7 +1212,10 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
|
||||
replica_groups=_replica_groups_hlo(replica_groups),
|
||||
**other_args).results
|
||||
|
||||
def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
|
||||
|
||||
def _all_gather_effectful_abstract_eval(
|
||||
x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled
|
||||
):
|
||||
if not isinstance(axis_name, (list, tuple)):
|
||||
axis_name = (axis_name,)
|
||||
x_aval = raise_to_shaped(x)
|
||||
@ -1215,7 +1226,10 @@ def _all_gather_abstract_eval(x, *, all_gather_dimension, axis_name, axis_index_
|
||||
new_shape.insert(all_gather_dimension, axis_size)
|
||||
new_named_shape = {name: size for name, size in x_aval.named_shape.items()
|
||||
if name not in axis_name}
|
||||
return x_aval.update(shape=new_shape, named_shape=new_named_shape)
|
||||
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
|
||||
effects = {*map(core.NamedAxisEffect, axis_name)}
|
||||
return out_aval, effects
|
||||
|
||||
|
||||
def _all_gather_transpose_rule(cts, x, *, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled):
|
||||
return (psum_scatter(cts, axis_name=axis_name,
|
||||
@ -1264,7 +1278,7 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
|
||||
return y, batching.not_mapped
|
||||
|
||||
all_gather_p = core.AxisPrimitive('all_gather')
|
||||
all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
|
||||
all_gather_p.def_effectful_abstract_eval(_all_gather_effectful_abstract_eval)
|
||||
all_gather_p.def_impl(_all_gather_impl)
|
||||
mlir.register_lowering(all_gather_p, _all_gather_lowering)
|
||||
for p in ("cuda", "rocm", "tpu"):
|
||||
@ -1327,9 +1341,9 @@ def _reduce_scatter_lowering(
|
||||
return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]
|
||||
|
||||
|
||||
|
||||
def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
|
||||
axis_index_groups, axis_size, tiled):
|
||||
def _reduce_scatter_effectful_abstract_eval(
|
||||
x, *, axis_name, scatter_dimension, axis_index_groups, axis_size, tiled
|
||||
):
|
||||
if not isinstance(axis_name, (list, tuple)):
|
||||
axis_name = (axis_name,)
|
||||
x_aval = core.raise_to_shaped(x)
|
||||
@ -1353,7 +1367,9 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,
|
||||
for name, size in x_aval.named_shape.items()
|
||||
if name not in axis_name
|
||||
}
|
||||
return x_aval.update(shape=new_shape, named_shape=new_named_shape)
|
||||
out_aval = x_aval.update(shape=new_shape, named_shape=new_named_shape)
|
||||
effects = {*map(core.NamedAxisEffect, axis_name)}
|
||||
return out_aval, effects
|
||||
|
||||
|
||||
def _reduce_scatter_transpose_rule(cts, x, *, axis_name, scatter_dimension,
|
||||
@ -1401,7 +1417,9 @@ def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in,
|
||||
|
||||
|
||||
reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
|
||||
reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
|
||||
reduce_scatter_p.def_effectful_abstract_eval(
|
||||
_reduce_scatter_effectful_abstract_eval
|
||||
)
|
||||
ad.deflinear2(reduce_scatter_p, _reduce_scatter_transpose_rule)
|
||||
batching.primitive_batchers[reduce_scatter_p] = _reduce_scatter_batcher
|
||||
batching.axis_primitive_batchers[reduce_scatter_p] = _reduce_scatter_collective
|
||||
@ -1537,13 +1555,15 @@ def _axis_index_lowering(ctx, *, axis_name):
|
||||
]
|
||||
|
||||
|
||||
def _axis_index_abstract_eval(*, axis_name):
|
||||
def _axis_index_effectful_abstract_eval(*, axis_name):
|
||||
frame = core.axis_frame(axis_name)
|
||||
return ShapedArray((), np.int32, named_shape={axis_name: frame.size})
|
||||
out_aval = ShapedArray((), np.int32, named_shape={axis_name: frame.size})
|
||||
return out_aval, {core.NamedAxisEffect(axis_name)}
|
||||
|
||||
|
||||
axis_index_p = core.Primitive('axis_index')
|
||||
mlir.register_lowering(axis_index_p, _axis_index_lowering)
|
||||
axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
|
||||
axis_index_p.def_effectful_abstract_eval(_axis_index_effectful_abstract_eval)
|
||||
core.axis_substitution_rules[axis_index_p] = partial(_subst_all_names_in_param, 'axis_name')
|
||||
|
||||
# Axis index doesn't get any arguments, so that the default bind would have no
|
||||
@ -1585,8 +1605,11 @@ def _pdot_impl(x, y, *, axis_name, pos_contract, pos_batch, precision):
|
||||
if axis_name: raise NameError(f"unbound axis name: {axis_name[0]}")
|
||||
return lax.dot_general(x, y, (pos_contract, pos_batch), precision=precision)
|
||||
|
||||
@pdot_p.def_abstract_eval
|
||||
def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
|
||||
|
||||
@pdot_p.def_effectful_abstract_eval
|
||||
def _pdot_effectful_abstract_eval(
|
||||
x, y, *, axis_name, pos_contract, pos_batch, precision
|
||||
):
|
||||
# TODO(frostig,mattjj,jekbradbury): check inputs have given axis names?
|
||||
if not len(set(axis_name)) == len(axis_name): raise ValueError
|
||||
pos_aval = lax.dot_general_p.abstract_eval(
|
||||
@ -1596,7 +1619,10 @@ def _pdot_abstract_eval(x, y, *, axis_name, pos_contract, pos_batch, precision):
|
||||
named_shape = {name: size
|
||||
for name, size in common_named_shape.items()
|
||||
if name not in axis_name}
|
||||
return pos_aval.update(named_shape=named_shape)
|
||||
out_aval = pos_aval.update(named_shape=named_shape)
|
||||
effects = {*map(core.NamedAxisEffect, axis_name)}
|
||||
return out_aval, effects
|
||||
|
||||
|
||||
def _pdot_vmap_collective_rule(axis_size, frame_name, _, vals_in, dims_in, *, axis_name,
|
||||
pos_contract, pos_batch, precision):
|
||||
|
@ -937,7 +937,8 @@ def _typecheck_xmap(
|
||||
mapped_out_avals = [v.aval for v in call_jaxpr.outvars]
|
||||
out_avals = [_insert_aval_axes(a, a_out_axes, local_axis_sizes)
|
||||
for a, a_out_axes in zip(mapped_out_avals, out_axes)]
|
||||
return out_avals, call_jaxpr.effects
|
||||
effs = core.filter_named_axis_effects(call_jaxpr.effects, global_axis_sizes)
|
||||
return out_avals, effs
|
||||
core.custom_typechecks[xmap_p] = _typecheck_xmap
|
||||
|
||||
|
||||
@ -1033,8 +1034,9 @@ def _dynamic_jaxpr_process_xmap(self, primitive, f, tracers, params):
|
||||
call_jaxpr=call_jaxpr)
|
||||
del new_params['out_axes_thunk']
|
||||
del new_params['spmd_out_axes_thunk']
|
||||
effs = core.filter_named_axis_effects(call_jaxpr.effects, global_axis_sizes)
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, primitive,
|
||||
new_params, call_jaxpr.effects, source_info)
|
||||
new_params, effs, source_info)
|
||||
self.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_xmap = _dynamic_jaxpr_process_xmap # type: ignore
|
||||
|
@ -426,7 +426,9 @@ def _shard_map_staging(
|
||||
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
|
||||
main = trace.main
|
||||
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
|
||||
jaxpr, genavals, consts, () = pe.trace_to_subjaxpr_dynamic(
|
||||
f, main, in_avals_
|
||||
)
|
||||
out_avals_ = map(_check_shapedarray, genavals)
|
||||
_check_names(out_names_thunk(), out_avals_)
|
||||
in_rep = map(partial(_in_names_to_rep, mesh), in_names)
|
||||
@ -445,8 +447,9 @@ def _shard_map_staging(
|
||||
params = dict(mesh=mesh, in_names=in_names_staged,
|
||||
out_names=tuple(out_names_thunk()), jaxpr=jaxpr,
|
||||
check_rep=check_rep, rewrite=rewrite, auto=auto)
|
||||
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
|
||||
eqn = pe.new_jaxpr_eqn([*constvars, *invars], outvars, prim, params,
|
||||
jaxpr.effects, source_info)
|
||||
effs, source_info)
|
||||
trace.frame.add_eqn(eqn)
|
||||
return out_tracers
|
||||
pe.DynamicJaxprTrace.process_shard_map = _shard_map_staging
|
||||
@ -495,7 +498,8 @@ def _shard_map_typecheck(_, *in_atoms, jaxpr, mesh, in_names, out_names,
|
||||
"sufficiently replicated")
|
||||
out_avals_sharded = [x.aval for x in jaxpr.outvars]
|
||||
out_avals = map(partial(_unshard_aval, mesh), out_names, out_avals_sharded)
|
||||
return out_avals, jaxpr.effects
|
||||
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
|
||||
return out_avals, effs
|
||||
core.custom_typechecks[shard_map_p] = _shard_map_typecheck
|
||||
|
||||
def _in_names_to_rep(mesh: Mesh, names: AxisNames) -> set[AxisName]:
|
||||
@ -1305,9 +1309,10 @@ def _shard_map_partial_eval(trace, shard_map_p, f, tracers, mesh, in_names,
|
||||
out_avals = map(partial(_unshard_aval, mesh), unk_out_names, out_avals_sharded)
|
||||
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
|
||||
for a in out_avals]
|
||||
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
|
||||
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers, *unk_arg_tracers), # type: ignore[arg-type]
|
||||
out_tracers, shard_map_p, unk_params,
|
||||
jaxpr.effects, source_info_util.current())
|
||||
effs, source_info_util.current())
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return pe.merge_lists(out_knowns, out_tracers, out_consts)
|
||||
pe.JaxprTrace.process_shard_map = _shard_map_partial_eval
|
||||
@ -1345,8 +1350,9 @@ def _shard_map_partial_eval_post_process(
|
||||
for a in out_avals]
|
||||
name_stack = trace._current_truncated_name_stack()
|
||||
source = source_info_util.current().replace(name_stack=name_stack)
|
||||
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
|
||||
eqn = pe.new_eqn_recipe((*const_tracers, *env_tracers), out_tracers,
|
||||
shard_map_p, staged_params, jaxpr.effects, source)
|
||||
shard_map_p, staged_params, effs, source)
|
||||
for t in out_tracers: t.recipe = eqn
|
||||
return merge_lists(out_knowns, out_tracers, out_consts)
|
||||
|
||||
@ -1456,6 +1462,8 @@ def _partial_eval_jaxpr_custom_rule(
|
||||
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
|
||||
jaxpr_known = pe.prune_jaxpr_outputs(jaxpr_known, [True] * num_out_primals + which)
|
||||
jaxpr_known, jaxpr_staged = _add_reshapes(which, jaxpr_known, jaxpr_staged)
|
||||
jaxpr_known = core.remove_named_axis_effects(jaxpr_known, mesh.axis_names)
|
||||
jaxpr_staged = core.remove_named_axis_effects(jaxpr_staged, mesh.axis_names)
|
||||
ins_known, _ = partition_list(unks_in, eqn.invars)
|
||||
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
|
||||
_, ins_staged = partition_list(inst_in, eqn.invars)
|
||||
@ -1535,7 +1543,8 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,
|
||||
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
|
||||
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
) -> tuple[list[bool], core.JaxprEqn | None]:
|
||||
with core.extend_axis_env_nd(eqn.params['mesh'].shape.items()):
|
||||
mesh = eqn.params["mesh"]
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
|
||||
if not any(used_inputs) and not any(used_outputs) and not jaxpr.effects:
|
||||
return used_inputs, None
|
||||
@ -1544,10 +1553,11 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
|
||||
_, out_names = partition_list(used_outputs, eqn.params['out_names'])
|
||||
new_params = dict(eqn.params, jaxpr=jaxpr, in_names=tuple(in_names),
|
||||
out_names=tuple(out_names))
|
||||
effs = core.filter_named_axis_effects(jaxpr.effects, mesh.axis_names)
|
||||
new_eqn = pe.new_jaxpr_eqn(
|
||||
[v for v, used in zip(eqn.invars, used_inputs) if used],
|
||||
[x for x, used in zip(eqn.outvars, used_outputs) if used],
|
||||
eqn.primitive, new_params, jaxpr.effects, eqn.source_info)
|
||||
eqn.primitive, new_params, effs, eqn.source_info)
|
||||
return used_inputs, new_eqn
|
||||
pe.dce_rules[shard_map_p] = _shard_map_dce
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user