Add more return type annotations.

Fix a new pytype error by adding a checked cast.

PiperOrigin-RevId: 523780354
This commit is contained in:
Peter Hawkins 2023-04-12 12:53:32 -07:00 committed by jax authors
parent bed81eb013
commit 49e68dbe80
4 changed files with 58 additions and 26 deletions

View File

@ -1633,17 +1633,18 @@ class JaxprStackFrame:
def add_eqn(self, eqn: core.JaxprEqn):
self.eqns.append(eqn)
def to_jaxpr(self, out_tracers):
def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> Tuple[Jaxpr, List[Any]]:
# It's not necessary, but we keep the tracer-to-var mapping injective:
assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvals: Sequence[Any]
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
self.eqns)
jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
return jaxpr, constvals
return jaxpr, list(constvals)
def to_jaxpr2(self, out_tracers):
# It's not necessary, but we keep the tracer-to-var mapping injective:
@ -1684,7 +1685,8 @@ class JaxprStackFrame:
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns
def _const_folding_and_forwarding(jaxpr, constvals):
def _const_folding_and_forwarding(
jaxpr: Jaxpr, constvals: Sequence[Any]) -> Tuple[Jaxpr, Tuple[Any, ...]]:
consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
var_subs: Dict[Var, Var] = {} # not Dict[Var, Atom] b/c literals not inlined
new_eqns = []
@ -1728,7 +1730,10 @@ ForwardingRule = Callable[[JaxprEqn],
Tuple[List[Optional[Var]], Optional[JaxprEqn]]]
forwarding_rules: Dict[Primitive, ForwardingRule] = {}
def _inline_literals(jaxpr, constvals):
def _inline_literals(
jaxpr: Jaxpr, constvals: Sequence[Any]
) -> Tuple[Jaxpr, List[Any]]:
# This function also prunes unused constants and inserts `dropvar` symbols.
input_effects = {eff for eff in jaxpr.effects
if isinstance(eff, effects.JaxprInputEffect)}
@ -1770,6 +1775,7 @@ def _inline_literals(jaxpr, constvals):
jaxpr_effects, jaxpr.debug_info)
return new_jaxpr, new_constvals
class DynamicJaxprTrace(core.Trace):
__slots__ = [] # type: ignore
@ -2108,12 +2114,15 @@ def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]:
else:
return [path for path, _ in generate_key_paths(dummy_result)]
@profiler.annotate_function
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
*,
keep_inputs: Optional[List[bool]] = None):
def trace_to_jaxpr_dynamic(
fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
*,
keep_inputs: Optional[List[bool]] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
@ -2121,10 +2130,15 @@ def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
del main, fun
return jaxpr, out_avals, consts
def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
in_avals: Sequence[AbstractValue], *,
keep_inputs: Optional[Sequence[bool]] = None,
debug_info: Optional[DebugInfo] = None):
def trace_to_subjaxpr_dynamic(
fun: lu.WrappedFun,
main: core.MainTrace,
in_avals: Sequence[AbstractValue],
*,
keep_inputs: Optional[Sequence[bool]] = None,
debug_info: Optional[DebugInfo] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
frame = JaxprStackFrame()
@ -2178,11 +2192,14 @@ def extend_jaxpr_stack(main, frame):
assert frame is main.jaxpr_stack[-1]
main.jaxpr_stack = main.jaxpr_stack[:-1]
@profiler.annotate_function
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
keep_inputs: Optional[Sequence[bool]] = None):
def trace_to_jaxpr_final(
fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: Optional[DebugInfo] = None,
keep_inputs: Optional[Sequence[bool]] = None,
) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.jaxpr_stack = () # type: ignore
with core.new_sublevel():
@ -2191,6 +2208,7 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
del fun, main
return jaxpr, out_avals, consts
@profiler.annotate_function
def trace_to_jaxpr_final2(
fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None

View File

@ -724,7 +724,9 @@ class ReplicaInfo(NamedTuple):
num_global_replicas: int
def find_replicas(jaxpr, axis_size, global_axis_size):
def find_replicas(
jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int
) -> ReplicaInfo:
# TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
num_local_replicas = axis_size * jaxpr_replicas
@ -733,8 +735,8 @@ def find_replicas(jaxpr, axis_size, global_axis_size):
def stage_parallel_callable(
pci: ParallelCallableInfo,
fun: lu.WrappedFun):
pci: ParallelCallableInfo, fun: lu.WrappedFun
) -> Tuple[core.Jaxpr, List[Any], ReplicaInfo, ShardInfo]:
sharded_avals = tuple(
shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
for axis, aval in safe_zip(pci.in_axes, pci.avals))

View File

@ -17,7 +17,7 @@ import numpy as np
import itertools as it
from collections import OrderedDict, abc
from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
NamedTuple, Union, Sequence)
NamedTuple, Union, Sequence, Mapping)
from functools import wraps, partial, partialmethod, lru_cache
from jax import lax
@ -271,8 +271,8 @@ def xmap(fun: Callable,
in_axes,
out_axes,
*,
axis_sizes: Dict[AxisName, int] = {},
axis_resources: Dict[AxisName, ResourceSet] = {},
axis_sizes: Optional[Mapping[AxisName, int]] = None,
axis_resources: Optional[Mapping[AxisName, ResourceSet]] = None,
donate_argnums: Union[int, Sequence[int]] = (),
backend: Optional[str] = None) -> stages.Wrapped:
"""Assign a positional signature to a program that uses named array axes.
@ -460,6 +460,9 @@ def xmap(fun: Callable,
out_axes, out_axes_entries, out_axes_treedef = _prepare_axes(out_axes, "out_axes")
out_axes_entries = tuple(out_axes_entries) # Make entries hashable
axis_sizes = {} if axis_sizes is None else axis_sizes
axis_resources = {} if axis_resources is None else axis_resources
axis_sizes_names = set(axis_sizes.keys())
in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
defined_names = axis_sizes_names | in_axes_names

View File

@ -14,13 +14,13 @@
from __future__ import annotations
import enum
from functools import partial, lru_cache
from functools import partial
import inspect
import itertools as it
import math
import operator as op
from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
Set, Tuple, TypeVar, Union, Protocol)
Set, Tuple, TypeVar, Union, cast)
import numpy as np
@ -358,6 +358,14 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
# Staging
def _check_avals_are_shapedarrays(
avals: Sequence[core.AbstractValue],
) -> Sequence[core.ShapedArray]:
assert all(isinstance(a, core.ShapedArray) for a in avals), avals
return cast(Sequence[core.ShapedArray], avals)
def _shard_map_staging(
trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
@ -369,7 +377,8 @@ def _shard_map_staging(
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
main = trace.main
with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
out_avals_ = _check_avals_are_shapedarrays(out_avals_generic)
_check_names(out_names_thunk(), out_avals_)
if check_rep:
in_rep = map(partial(_in_names_to_rep, mesh), in_names)