[typing] annotate jax._src.util.safe_map

This commit is contained in:
Jake VanderPlas 2022-10-20 10:15:04 -07:00
parent 1a0affddd8
commit 5d15757741
9 changed files with 34 additions and 16 deletions

View File

@ -426,6 +426,7 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
heights: Dict[Tuple[int, ...], int] = {}
widths: Dict[Tuple[int, ...], int] = {}
for dev, slcs in device_indices_map.items():
assert slcs is not None
slcs = tuple(map(_raise_to_slice, slcs))
chunk_idxs = tuple(map(_slice_to_chunk_idx, shape, slcs))
if slcs is None:

View File

@ -21,8 +21,8 @@ from functools import partial
import itertools
import time
from typing import (
Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union,
Iterator)
Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
Set, Tuple, List, Type, Union)
from typing_extensions import Protocol
import logging
import os
@ -226,7 +226,7 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
donated_invars, inline, keep_unused: bool):
del inline # Only used at tracing time
if fun.in_type is None:
arg_specs = unsafe_map(arg_spec, args)
arg_specs: Iterable[Any] = unsafe_map(arg_spec, args)
else:
# fun.in_type is used for dynamic shapes.
if config.jax_array:

View File

@ -327,7 +327,7 @@ class PmapSharding(XLACompatibleSharding):
@functools.lru_cache(maxsize=4096)
def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
return {d: i for d, i in safe_zip(self.devices.flat, indices)} # type: ignore
return dict(safe_zip(self.devices.flat, indices)) # type: ignore[arg-type]
@pxla.maybe_cached_property
def _device_assignment(self) -> XLADeviceAssignment:

View File

@ -72,8 +72,7 @@ def _unpack_idx(idx: Indexer, ndim: int
indexed_dims_ = [type(i) != slice for i in idx]
_, non_slice_idx = partition_list(indexed_dims_, idx)
indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
return (tuple(map(lambda x: jnp.asarray(x, jnp.int32), non_slice_idx)),
tuple(indexed_dims))
return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))
def _get_slice_output_shape(in_shape: Tuple[int, ...],
idx_shapes: Tuple[Tuple[int, ...], ...],

View File

@ -58,6 +58,21 @@ def safe_zip(*args):
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
return list(zip(*args))
# safe_map cannot yet be fully annotated, so we use a strategy similar
# to that used for builtins.map in python/typeshed. This supports
# checking input types for the callable with up to three arguments.
@overload
def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...
@overload
def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[T]: ...
@overload
def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[T]: ...
@overload
def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[T]: ...
def safe_map(f, *args):
args = list(map(list, args))
n = len(args[0])

View File

@ -2624,7 +2624,8 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
]
if config.jax_experimental_name_stack:
# Same name stack as XLA translation of cond_p
branches_tf = list(map(source_info_util.extend_name_stack("cond"),
# Note: extend_name_stack is a contextmanager, which is callable as a decorator.
branches_tf = list(map(source_info_util.extend_name_stack("cond"), # type: ignore[arg-type]
branches_tf))
return tf.switch_case(index, branches_tf)
@ -3042,7 +3043,7 @@ def _pjit(*args: TfVal,
in_positional_semantics,
out_positional_semantics,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray) -> TfVal:
_out_aval: Sequence[core.ShapedArray]) -> TfVal:
del donated_invars
if resource_env.physical_mesh.is_multi_process:
raise NotImplementedError("jax2tf translation for pjit over multi-process "

View File

@ -48,7 +48,7 @@ DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,
import functools
from typing import (
Any, Callable, Dict, NamedTuple, List, Optional, Sequence, Tuple, Union)
Any, Callable, Dict, NamedTuple, List, Optional, Sequence, Tuple)
import numpy as np
@ -353,11 +353,12 @@ def eval_sparse(
) -> Sequence[SparsifyValue]:
env : Dict[core.Var, SparsifyValue] = {}
def read(var: core.Var) -> Union[Array, SparsifyValue]:
def read(var: core.Atom) -> SparsifyValue:
# all literals are dense
if isinstance(var, core.Literal):
return spenv.dense(var.val)
else:
assert isinstance(var, core.Var)
return env[var]
def write_buffer(var: core.Var, a: Array) -> None:

View File

@ -977,13 +977,14 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
Assumes that an MLIR context, location, and insertion point are set.
"""
assert ctx.platform != "gpu"
def read(v: core.Var) -> Sequence[ir.Value]:
def read(v: core.Atom) -> Sequence[ir.Value]:
if type(v) is core.Literal:
return ir_constants(v.val, canonicalize_types=True)
else:
assert isinstance(v, core.Var)
return env[v]
def aval(v: core.Var) -> core.AbstractValue:
def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return xla.abstractify(v.val)
else:

View File

@ -31,7 +31,6 @@ from jax import linear_util as lu
from jax._src import api_util
from jax._src import dtypes
from jax._src import profiler
from jax._src.ad_util import Zero
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
@ -379,7 +378,7 @@ class JaxprTrace(Trace):
for ax, a in zip(staged_out_axes, out_avals_mapped)]
out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
for a in out_avals]
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers), # type: ignore[arg-type]
out_tracers, primitive, staged_params,
jaxpr.effects,
source_info_util.current())
@ -945,7 +944,7 @@ def tracers_to_jaxpr(
env_vars, env_vals = unzip2(env.items())
const_vars, const_vals = unzip2(consts.items())
effects = core.join_effects(*(eqn.effects for eqn in eqns))
jaxpr = Jaxpr(const_vars, [*env_vars, *map(get_atom, in_tracers)],
jaxpr = Jaxpr(const_vars, [*env_vars, *map(get_atom, in_tracers)], # type: ignore[list-item]
map(get_atom, out_tracers), eqns, effects)
config.jax_enable_checks and core.check_jaxpr(jaxpr)
# del getvar # needed to avoid cyclic-reference closure, apparently!
@ -1166,7 +1165,8 @@ def _partial_eval_jaxpr_custom_cached(
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(eqn.replace(invars=inputs))
map(partial(write, False, True), eqn.outvars)
out_unknowns, out_inst = unsafe_map(list, unzip2(map(read, jaxpr.outvars)))
unzipped = unzip2(map(read, jaxpr.outvars))
out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1])
assert all(type(v) is Var for v in residuals), residuals
for x, inst, ensure_inst in zip(jaxpr.outvars, out_inst, ensure_out_inst):