diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 06c2fe31a..d66145d68 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -1633,17 +1633,18 @@ class JaxprStackFrame: def add_eqn(self, eqn: core.JaxprEqn): self.eqns.append(eqn) - def to_jaxpr(self, out_tracers): + def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> Tuple[Jaxpr, List[Any]]: # It's not necessary, but we keep the tracer-to-var mapping injective: assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values())) outvars = [self.tracer_to_var[id(t)] for t in out_tracers] + constvals: Sequence[Any] constvars, constvals = unzip2(self.constvar_to_val.items()) jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars, self.eqns) jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects) jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals) jaxpr, constvals = _inline_literals(jaxpr, constvals) - return jaxpr, constvals + return jaxpr, list(constvals) def to_jaxpr2(self, out_tracers): # It's not necessary, but we keep the tracer-to-var mapping injective: @@ -1684,7 +1685,8 @@ class JaxprStackFrame: const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars] return invar_positions, const_eqns -def _const_folding_and_forwarding(jaxpr, constvals): +def _const_folding_and_forwarding( + jaxpr: Jaxpr, constvals: Sequence[Any]) -> Tuple[Jaxpr, Tuple[Any, ...]]: consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals)) var_subs: Dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined new_eqns = [] @@ -1728,7 +1730,10 @@ ForwardingRule = Callable[[JaxprEqn], Tuple[List[Optional[Var]], Optional[JaxprEqn]]] forwarding_rules: Dict[Primitive, ForwardingRule] = {} -def _inline_literals(jaxpr, constvals): + +def _inline_literals( + jaxpr: Jaxpr, constvals: Sequence[Any] +) -> Tuple[Jaxpr, List[Any]]: # This function also prunes unused constants and inserts `dropvar` symbols. input_effects = {eff for eff in jaxpr.effects if isinstance(eff, effects.JaxprInputEffect)} @@ -1770,6 +1775,7 @@ def _inline_literals(jaxpr, constvals): jaxpr_effects, jaxpr.debug_info) return new_jaxpr, new_constvals + class DynamicJaxprTrace(core.Trace): __slots__ = [] # type: ignore @@ -2108,12 +2114,15 @@ def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]: else: return [path for path, _ in generate_key_paths(dummy_result)] + @profiler.annotate_function -def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue], - debug_info: Optional[DebugInfo] = None, - *, - keep_inputs: Optional[List[bool]] = None): +def trace_to_jaxpr_dynamic( + fun: lu.WrappedFun, + in_avals: Sequence[AbstractValue], + debug_info: Optional[DebugInfo] = None, + *, + keep_inputs: Optional[List[bool]] = None, +) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]: with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore main.jaxpr_stack = () # type: ignore jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( @@ -2121,10 +2130,15 @@ def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, del main, fun return jaxpr, out_avals, consts -def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace, - in_avals: Sequence[AbstractValue], *, - keep_inputs: Optional[Sequence[bool]] = None, - debug_info: Optional[DebugInfo] = None): + +def trace_to_subjaxpr_dynamic( + fun: lu.WrappedFun, + main: core.MainTrace, + in_avals: Sequence[AbstractValue], + *, + keep_inputs: Optional[Sequence[bool]] = None, + debug_info: Optional[DebugInfo] = None, +) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]: keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs frame = JaxprStackFrame() @@ -2178,11 +2192,14 @@ def extend_jaxpr_stack(main, frame): assert frame is main.jaxpr_stack[-1] main.jaxpr_stack = main.jaxpr_stack[:-1] + @profiler.annotate_function -def trace_to_jaxpr_final(fun: lu.WrappedFun, - in_avals: Sequence[AbstractValue], - debug_info: Optional[DebugInfo] = None, - keep_inputs: Optional[Sequence[bool]] = None): +def trace_to_jaxpr_final( + fun: lu.WrappedFun, + in_avals: Sequence[AbstractValue], + debug_info: Optional[DebugInfo] = None, + keep_inputs: Optional[Sequence[bool]] = None, +) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]: with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore main.jaxpr_stack = () # type: ignore with core.new_sublevel(): @@ -2191,6 +2208,7 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun, del fun, main return jaxpr, out_avals, consts + @profiler.annotate_function def trace_to_jaxpr_final2( fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 299729974..4b5b4dd94 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -724,7 +724,9 @@ class ReplicaInfo(NamedTuple): num_global_replicas: int -def find_replicas(jaxpr, axis_size, global_axis_size): +def find_replicas( + jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int +) -> ReplicaInfo: # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr) num_local_replicas = axis_size * jaxpr_replicas @@ -733,8 +735,8 @@ def find_replicas(jaxpr, axis_size, global_axis_size): def stage_parallel_callable( - pci: ParallelCallableInfo, - fun: lu.WrappedFun): + pci: ParallelCallableInfo, fun: lu.WrappedFun +) -> Tuple[core.Jaxpr, List[Any], ReplicaInfo, ShardInfo]: sharded_avals = tuple( shard_aval(pci.axis_size, axis, aval) if axis is not None else aval for axis, aval in safe_zip(pci.in_axes, pci.avals)) diff --git a/jax/_src/maps.py b/jax/_src/maps.py index e6471d7a1..ce8ce498f 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -17,7 +17,7 @@ import numpy as np import itertools as it from collections import OrderedDict, abc from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set, - NamedTuple, Union, Sequence) + NamedTuple, Union, Sequence, Mapping) from functools import wraps, partial, partialmethod, lru_cache from jax import lax @@ -271,8 +271,8 @@ def xmap(fun: Callable, in_axes, out_axes, *, - axis_sizes: Dict[AxisName, int] = {}, - axis_resources: Dict[AxisName, ResourceSet] = {}, + axis_sizes: Optional[Mapping[AxisName, int]] = None, + axis_resources: Optional[Mapping[AxisName, ResourceSet]] = None, donate_argnums: Union[int, Sequence[int]] = (), backend: Optional[str] = None) -> stages.Wrapped: """Assign a positional signature to a program that uses named array axes. @@ -460,6 +460,9 @@ def xmap(fun: Callable, out_axes, out_axes_entries, out_axes_treedef = _prepare_axes(out_axes, "out_axes") out_axes_entries = tuple(out_axes_entries) # Make entries hashable + axis_sizes = {} if axis_sizes is None else axis_sizes + axis_resources = {} if axis_resources is None else axis_resources + axis_sizes_names = set(axis_sizes.keys()) in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries))) defined_names = axis_sizes_names | in_axes_names diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 15aeab9fa..ad0fb0684 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -14,13 +14,13 @@ from __future__ import annotations import enum -from functools import partial, lru_cache +from functools import partial import inspect import itertools as it import math import operator as op from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence, - Set, Tuple, TypeVar, Union, Protocol) + Set, Tuple, TypeVar, Union, cast) import numpy as np @@ -358,6 +358,14 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep, # Staging + +def _check_avals_are_shapedarrays( + avals: Sequence[core.AbstractValue], +) -> Sequence[core.ShapedArray]: + assert all(isinstance(a, core.ShapedArray) for a in avals), avals + return cast(Sequence[core.ShapedArray], avals) + + def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh, @@ -369,7 +377,8 @@ def _shard_map_staging( in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) main = trace.main with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()): - jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) + jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_) + out_avals_ = _check_avals_are_shapedarrays(out_avals_generic) _check_names(out_names_thunk(), out_avals_) if check_rep: in_rep = map(partial(_in_names_to_rep, mesh), in_names)