Add more return type annotations.

Fix a new pytype error by adding a checked cast.

PiperOrigin-RevId: 523780354
This commit is contained in:
parent bed81eb013
commit 49e68dbe80
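For readers skimming the diff, the "checked cast" mentioned in the commit message is the usual pairing of a runtime assert with typing.cast: the assert verifies the property when the code actually runs, while cast (a no-op at runtime) tells the static checker, pytype here, about the narrowed type. Below is a minimal, self-contained sketch of that pattern with invented names; the actual helper added by this commit, _check_avals_are_shapedarrays, appears in the shard_map hunk further down.

from typing import Sequence, cast

class Value: ...              # illustrative stand-in base class
class ArrayValue(Value): ...  # illustrative stand-in subclass

def as_array_values(vals: Sequence[Value]) -> Sequence[ArrayValue]:
  # Runtime check: fail loudly if any element has the wrong type.
  assert all(isinstance(v, ArrayValue) for v in vals), vals
  # Static narrowing only: cast() does nothing at runtime.
  return cast(Sequence[ArrayValue], vals)

arrays = as_array_values([ArrayValue(), ArrayValue()])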
@@ -1633,17 +1633,18 @@ class JaxprStackFrame:
   def add_eqn(self, eqn: core.JaxprEqn):
     self.eqns.append(eqn)
 
-  def to_jaxpr(self, out_tracers):
+  def to_jaxpr(self, out_tracers: Sequence[Tracer]) -> Tuple[Jaxpr, List[Any]]:
     # It's not necessary, but we keep the tracer-to-var mapping injective:
     assert len(self.tracer_to_var) == len(set(self.tracer_to_var.values()))
     outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
+    constvals: Sequence[Any]
     constvars, constvals = unzip2(self.constvar_to_val.items())
     jaxpr_effects = make_jaxpr_effects(constvars, self.invars, outvars,
                                        self.eqns)
     jaxpr = Jaxpr(constvars, self.invars, outvars, self.eqns, jaxpr_effects)
     jaxpr, constvals = _const_folding_and_forwarding(jaxpr, constvals)
     jaxpr, constvals = _inline_literals(jaxpr, constvals)
-    return jaxpr, constvals
+    return jaxpr, list(constvals)
 
   def to_jaxpr2(self, out_tracers):
     # It's not necessary, but we keep the tracer-to-var mapping injective:
@@ -1684,7 +1685,8 @@ class JaxprStackFrame:
     const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
     return invar_positions, const_eqns
 
-def _const_folding_and_forwarding(jaxpr, constvals):
+def _const_folding_and_forwarding(
+    jaxpr: Jaxpr, constvals: Sequence[Any]) -> Tuple[Jaxpr, Tuple[Any, ...]]:
   consts: Dict[Var, Any] = dict(zip(jaxpr.constvars, constvals))
   var_subs: Dict[Var, Var] = {}  # not Dict[Var, Atom] b/c literals not inlined
   new_eqns = []
@@ -1728,7 +1730,10 @@ ForwardingRule = Callable[[JaxprEqn],
                           Tuple[List[Optional[Var]], Optional[JaxprEqn]]]
 forwarding_rules: Dict[Primitive, ForwardingRule] = {}
 
-def _inline_literals(jaxpr, constvals):
+
+def _inline_literals(
+    jaxpr: Jaxpr, constvals: Sequence[Any]
+) -> Tuple[Jaxpr, List[Any]]:
   # This function also prunes unused constants and inserts `dropvar` symbols.
   input_effects = {eff for eff in jaxpr.effects
                    if isinstance(eff, effects.JaxprInputEffect)}
@@ -1770,6 +1775,7 @@ def _inline_literals(jaxpr, constvals):
                     jaxpr_effects, jaxpr.debug_info)
   return new_jaxpr, new_constvals
 
+
 class DynamicJaxprTrace(core.Trace):
   __slots__ = []  # type: ignore
 
@@ -2108,12 +2114,15 @@ def result_info(dbg: DebugInfo) -> Optional[List[KeyPath]]:
   else:
     return [path for path, _ in generate_key_paths(dummy_result)]
 
+
 @profiler.annotate_function
-def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
-                           in_avals: Sequence[AbstractValue],
-                           debug_info: Optional[DebugInfo] = None,
-                           *,
-                           keep_inputs: Optional[List[bool]] = None):
+def trace_to_jaxpr_dynamic(
+    fun: lu.WrappedFun,
+    in_avals: Sequence[AbstractValue],
+    debug_info: Optional[DebugInfo] = None,
+    *,
+    keep_inputs: Optional[List[bool]] = None,
+) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
   with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
     main.jaxpr_stack = ()  # type: ignore
     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
@@ -2121,10 +2130,15 @@ def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
   del main, fun
   return jaxpr, out_avals, consts
 
-def trace_to_subjaxpr_dynamic(fun: lu.WrappedFun, main: core.MainTrace,
-                              in_avals: Sequence[AbstractValue], *,
-                              keep_inputs: Optional[Sequence[bool]] = None,
-                              debug_info: Optional[DebugInfo] = None):
+
+def trace_to_subjaxpr_dynamic(
+    fun: lu.WrappedFun,
+    main: core.MainTrace,
+    in_avals: Sequence[AbstractValue],
+    *,
+    keep_inputs: Optional[Sequence[bool]] = None,
+    debug_info: Optional[DebugInfo] = None,
+) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
   keep_inputs = [True] * len(in_avals) if keep_inputs is None else keep_inputs
 
   frame = JaxprStackFrame()
@@ -2178,11 +2192,14 @@ def extend_jaxpr_stack(main, frame):
   assert frame is main.jaxpr_stack[-1]
   main.jaxpr_stack = main.jaxpr_stack[:-1]
 
+
 @profiler.annotate_function
-def trace_to_jaxpr_final(fun: lu.WrappedFun,
-                         in_avals: Sequence[AbstractValue],
-                         debug_info: Optional[DebugInfo] = None,
-                         keep_inputs: Optional[Sequence[bool]] = None):
+def trace_to_jaxpr_final(
+    fun: lu.WrappedFun,
+    in_avals: Sequence[AbstractValue],
+    debug_info: Optional[DebugInfo] = None,
+    keep_inputs: Optional[Sequence[bool]] = None,
+) -> Tuple[Jaxpr, List[AbstractValue], List[Any]]:
   with core.new_base_main(DynamicJaxprTrace) as main:  # type: ignore
     main.jaxpr_stack = ()  # type: ignore
     with core.new_sublevel():
@@ -2191,6 +2208,7 @@ def trace_to_jaxpr_final(fun: lu.WrappedFun,
   del fun, main
   return jaxpr, out_avals, consts
 
+
 @profiler.annotate_function
 def trace_to_jaxpr_final2(
     fun: lu.WrappedFun, debug_info: Optional[DebugInfo] = None
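One detail worth calling out in the to_jaxpr hunk above: since the new annotation promises Tuple[Jaxpr, List[Any]], the body pre-declares the local with `constvals: Sequence[Any]` (so the checker tolerates its later rebinding by the folding and inlining passes) and converts with list(constvals) at the return site. A rough sketch of the same idiom, with invented names:

from typing import Any, List, Sequence, Tuple

def drop_nones(values: Sequence[Any]) -> Tuple[List[Any], int]:
  # Pre-declare the local's type once so rebinding it to other
  # Sequence-compatible values (tuples here) stays well-typed.
  kept: Sequence[Any]
  kept = tuple(values)
  kept = tuple(v for v in kept if v is not None)
  # Convert at the return site to match the List[...] in the annotation.
  return list(kept), len(kept)

assert drop_nones([1, None, 2]) == ([1, 2], 2)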
@@ -724,7 +724,9 @@ class ReplicaInfo(NamedTuple):
   num_global_replicas: int
 
 
-def find_replicas(jaxpr, axis_size, global_axis_size):
+def find_replicas(
+    jaxpr: core.Jaxpr, axis_size: int, global_axis_size: int
+) -> ReplicaInfo:
   # TODO(skyewm): replace this with a chain of pmaps and/or sharded_jits
   jaxpr_replicas = dispatch.jaxpr_replicas(jaxpr)
   num_local_replicas = axis_size * jaxpr_replicas
@@ -733,8 +735,8 @@ def find_replicas(jaxpr, axis_size, global_axis_size):
 
 
 def stage_parallel_callable(
-    pci: ParallelCallableInfo,
-    fun: lu.WrappedFun):
+    pci: ParallelCallableInfo, fun: lu.WrappedFun
+) -> Tuple[core.Jaxpr, List[Any], ReplicaInfo, ShardInfo]:
   sharded_avals = tuple(
       shard_aval(pci.axis_size, axis, aval) if axis is not None else aval
       for axis, aval in safe_zip(pci.in_axes, pci.avals))
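The find_replicas hunk above also shows annotating a helper that returns a NamedTuple: once the return type is declared as ReplicaInfo, field access at call sites becomes checkable. A small self-contained sketch of that pattern with made-up names:

from typing import NamedTuple

class Totals(NamedTuple):   # made-up stand-in for a result record like ReplicaInfo
  local: int
  global_: int

def count_totals(per_device: int, num_devices: int, num_hosts: int) -> Totals:
  # The declared return type documents the field structure and lets the
  # checker verify attribute access (t.local, t.global_) at call sites.
  local = per_device * num_devices
  return Totals(local, local * num_hosts)

t = count_totals(2, 4, 3)
assert (t.local, t.global_) == (8, 24)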
@@ -17,7 +17,7 @@ import numpy as np
 import itertools as it
 from collections import OrderedDict, abc
 from typing import (Callable, Iterable, Tuple, Optional, Dict, Any, Set,
-                    NamedTuple, Union, Sequence)
+                    NamedTuple, Union, Sequence, Mapping)
 from functools import wraps, partial, partialmethod, lru_cache
 
 from jax import lax
@@ -271,8 +271,8 @@ def xmap(fun: Callable,
          in_axes,
          out_axes,
          *,
-         axis_sizes: Dict[AxisName, int] = {},
-         axis_resources: Dict[AxisName, ResourceSet] = {},
+         axis_sizes: Optional[Mapping[AxisName, int]] = None,
+         axis_resources: Optional[Mapping[AxisName, ResourceSet]] = None,
          donate_argnums: Union[int, Sequence[int]] = (),
          backend: Optional[str] = None) -> stages.Wrapped:
   """Assign a positional signature to a program that uses named array axes.
@@ -460,6 +460,9 @@ def xmap(fun: Callable,
   out_axes, out_axes_entries, out_axes_treedef = _prepare_axes(out_axes, "out_axes")
   out_axes_entries = tuple(out_axes_entries)  # Make entries hashable
 
+  axis_sizes = {} if axis_sizes is None else axis_sizes
+  axis_resources = {} if axis_resources is None else axis_resources
+
   axis_sizes_names = set(axis_sizes.keys())
   in_axes_names = set(it.chain(*(spec.keys() for spec in in_axes_entries)))
   defined_names = axis_sizes_names | in_axes_names
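Besides the Mapping import, the xmap hunks above replace mutable `{}` defaults with `Optional[Mapping[...]] = None` plus an explicit normalization inside the body, which avoids Python's shared-mutable-default pitfall and lets callers pass any Mapping rather than a concrete Dict. A short sketch of the same pattern under assumed names:

from typing import Mapping, Optional

def total_size(sizes: Optional[Mapping[str, int]] = None) -> int:
  # Normalize the None default instead of writing `sizes: ... = {}`,
  # which would create a single dict shared across every call.
  sizes = {} if sizes is None else sizes
  return sum(sizes.values())

assert total_size({"x": 2, "y": 4}) == 6
assert total_size() == 0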
@@ -14,13 +14,13 @@
 from __future__ import annotations
 
 import enum
-from functools import partial, lru_cache
+from functools import partial
 import inspect
 import itertools as it
 import math
 import operator as op
 from typing import (Any, Callable, Dict, Hashable, List, Optional, Sequence,
-                    Set, Tuple, TypeVar, Union, Protocol)
+                    Set, Tuple, TypeVar, Union, cast)
 
 import numpy as np
 
@@ -358,6 +358,14 @@ def process_env_traces(level: int, mesh, in_names, out_names_thunk, check_rep,
 
 # Staging
 
+
+def _check_avals_are_shapedarrays(
+    avals: Sequence[core.AbstractValue],
+) -> Sequence[core.ShapedArray]:
+  assert all(isinstance(a, core.ShapedArray) for a in avals), avals
+  return cast(Sequence[core.ShapedArray], avals)
+
+
 def _shard_map_staging(
     trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun,
     in_tracers: Sequence[pe.DynamicJaxprTracer], *, mesh: Mesh,
@@ -369,7 +377,8 @@ def _shard_map_staging(
   in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
   main = trace.main
   with core.new_sublevel(), core.extend_axis_env_nd(mesh.shape.items()):
-    jaxpr, out_avals_, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
+    jaxpr, out_avals_generic, consts = pe.trace_to_subjaxpr_dynamic(f, main, in_avals_)
+    out_avals_ = _check_avals_are_shapedarrays(out_avals_generic)
   _check_names(out_names_thunk(), out_avals_)
   if check_rep:
     in_rep = map(partial(_in_names_to_rep, mesh), in_names)