Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
[typing] annotate jax._src.util.safe_map

parent 1a0affddd8
commit 5d15757741
@@ -426,6 +426,7 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
   heights: Dict[Tuple[int, ...], int] = {}
   widths: Dict[Tuple[int, ...], int] = {}
   for dev, slcs in device_indices_map.items():
+    assert slcs is not None
     slcs = tuple(map(_raise_to_slice, slcs))
     chunk_idxs = tuple(map(_slice_to_chunk_idx, shape, slcs))
     if slcs is None:

@@ -21,8 +21,8 @@ from functools import partial
 import itertools
 import time
 from typing import (
-    Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union,
-    Iterator)
+    Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
+    Set, Tuple, List, Type, Union)
 from typing_extensions import Protocol
 import logging
 import os

@@ -226,7 +226,7 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                    donated_invars, inline, keep_unused: bool):
   del inline  # Only used at tracing time
   if fun.in_type is None:
-    arg_specs = unsafe_map(arg_spec, args)
+    arg_specs: Iterable[Any] = unsafe_map(arg_spec, args)
   else:
     # fun.in_type is used for dynamic shapes.
     if config.jax_array:

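The annotation added here pins `arg_specs` to a type both branches can satisfy, which is why `Iterable` joins the imports in the previous hunk. A minimal sketch of the idiom (the `arg_specs_for` function and its names are hypothetical, not from the commit):

```python
from typing import Any, Iterable

def arg_specs_for(args: Iterable[Any], dynamic: bool) -> Iterable[Any]:
  if not dynamic:
    # Declaring Iterable[Any] up front keeps mypy from inferring the concrete
    # list type here and then rejecting the generator assigned below.
    specs: Iterable[Any] = [(a, None) for a in args]
  else:
    specs = ((a, None) for a in args)  # a generator is also an Iterable
  return specs

print(list(arg_specs_for([1, 2], dynamic=True)))  # [(1, None), (2, None)]
```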
@@ -327,7 +327,7 @@ class PmapSharding(XLACompatibleSharding):
   @functools.lru_cache(maxsize=4096)
   def devices_indices_map(self, global_shape: Shape) -> Mapping[Device, Index]:
     indices = pxla.spec_to_indices(global_shape, self.sharding_spec)
-    return {d: i for d, i in safe_zip(self.devices.flat, indices)}  # type: ignore
+    return dict(safe_zip(self.devices.flat, indices))  # type: ignore[arg-type]

   @pxla.maybe_cached_property
   def _device_assignment(self) -> XLADeviceAssignment:

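Swapping the comprehension for `dict(safe_zip(...))` lets the checker infer the mapping type directly from `safe_zip`'s annotated return. A sketch assuming `safe_zip` is annotated like `zip` (the two-argument `safe_zip2` below is hypothetical):

```python
from typing import Dict, List, Tuple, TypeVar

T1 = TypeVar('T1')
T2 = TypeVar('T2')

def safe_zip2(xs: List[T1], ys: List[T2]) -> List[Tuple[T1, T2]]:
  # Length-checking zip in the spirit of jax._src.util.safe_zip.
  assert len(xs) == len(ys), f'length mismatch: {len(xs)} != {len(ys)}'
  return list(zip(xs, ys))

devices = ['dev0', 'dev1']
indices = [(0, 4), (4, 8)]
# dict() over List[Tuple[str, Tuple[int, int]]] is inferred as
# Dict[str, Tuple[int, int]] without annotating a comprehension.
mapping: Dict[str, Tuple[int, int]] = dict(safe_zip2(devices, indices))
print(mapping)  # {'dev0': (0, 4), 'dev1': (4, 8)}
```

The scoped ignore remaining in the hunk suggests `self.devices.flat` is still too loosely typed for a fully clean check.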
@@ -72,8 +72,7 @@ def _unpack_idx(idx: Indexer, ndim: int
   indexed_dims_ = [type(i) != slice for i in idx]
   _, non_slice_idx = partition_list(indexed_dims_, idx)
   indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
-  return (tuple(map(lambda x: jnp.asarray(x, jnp.int32), non_slice_idx)),
-          tuple(indexed_dims))
+  return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))

 def _get_slice_output_shape(in_shape: Tuple[int, ...],
                             idx_shapes: Tuple[Tuple[int, ...], ...],

@@ -58,6 +58,21 @@ def safe_zip(*args):
     assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
   return list(zip(*args))

+# safe_map cannot yet be fully annotated, so we use a strategy similar
+# to that used for builtins.map in python/typeshed. This supports
+# checking input types for the callable with up to three arguments.
+@overload
+def safe_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...
+
+@overload
+def safe_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> List[T]: ...
+
+@overload
+def safe_map(f: Callable[[T1, T2, T3], T], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> List[T]: ...
+
+@overload
+def safe_map(f: Callable[..., T], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> List[T]: ...
+
 def safe_map(f, *args):
   args = list(map(list, args))
   n = len(args[0])

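This hunk is the heart of the commit: precise overloads for low arities plus a permissive catch-all, mirroring typeshed's treatment of builtins.map. A self-contained sketch of the same strategy (the name `checked_map` is hypothetical, and it is trimmed to two precise arities where the real safe_map has three):

```python
from typing import Any, Callable, Iterable, List, TypeVar, overload

T = TypeVar('T')
T1 = TypeVar('T1')
T2 = TypeVar('T2')

@overload
def checked_map(f: Callable[[T1], T], __arg1: Iterable[T1]) -> List[T]: ...
@overload
def checked_map(f: Callable[[T1, T2], T], __arg1: Iterable[T1],
                __arg2: Iterable[T2]) -> List[T]: ...
@overload
def checked_map(f: Callable[..., T], __arg1: Iterable[Any],
                __arg2: Iterable[Any], __arg3: Iterable[Any],
                *args) -> List[T]: ...

def checked_map(f, *args):
  # Runtime implementation: map() that insists all inputs have equal length.
  args = list(map(list, args))
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
  return list(map(f, *args))

lengths = checked_map(len, ['ab', 'cde'])               # checked: List[int]
sums = checked_map(lambda a, b: a + b, [1, 2], [3, 4])  # checked: List[int]
# checked_map(len, [1, 2]) would now be flagged: int has no len().
print(lengths, sums)  # [2, 3] [4, 6]
```

The catch-all overload starts at three required iterables so that calls with fewer arguments always hit a precise overload first; runtime behavior is unchanged.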
@@ -2624,7 +2624,8 @@ def _cond(index: TfVal, *operands: TfVal, branches: Sequence[core.ClosedJaxpr],
   ]
   if config.jax_experimental_name_stack:
     # Same name stack as XLA translation of cond_p
-    branches_tf = list(map(source_info_util.extend_name_stack("cond"),
+    # Note: extend_name_stack is a contextmanager, which is callable as a decorator.
+    branches_tf = list(map(source_info_util.extend_name_stack("cond"),  # type: ignore[arg-type]
                            branches_tf))
   return tf.switch_case(index, branches_tf)

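The new comment leans on a stdlib detail worth spelling out: context managers produced by `@contextlib.contextmanager` inherit from `ContextDecorator`, so an instance can be called on a function and returns a wrapped version, which is exactly what `map` does with it here. A runnable sketch with a stand-in `name_stack` (not the real `source_info_util.extend_name_stack`):

```python
import contextlib

@contextlib.contextmanager
def name_stack(name: str):
  # Stand-in for extend_name_stack: push a name, yield, pop.
  print(f'enter {name}')
  try:
    yield
  finally:
    print(f'exit {name}')

def branch0() -> int:
  return 0

def branch1() -> int:
  return 1

# Calling the context-manager instance on each function wraps it. The type
# checker does not always accept the instance as a plain Callable, which is
# presumably why the hunk needs the scoped ignore.
branches = list(map(name_stack('cond'), [branch0, branch1]))
print(branches[0](), branches[1]())  # each branch runs inside the context
```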
@@ -3042,7 +3043,7 @@ def _pjit(*args: TfVal,
           in_positional_semantics,
           out_positional_semantics,
           _in_avals: Sequence[core.ShapedArray],
-          _out_aval: core.ShapedArray) -> TfVal:
+          _out_aval: Sequence[core.ShapedArray]) -> TfVal:
   del donated_invars
   if resource_env.physical_mesh.is_multi_process:
     raise NotImplementedError("jax2tf translation for pjit over multi-process "

@@ -48,7 +48,7 @@ DeviceArray([-1.2655463 , -0.52060574, -0.14522289, -0.10817424,

 import functools
 from typing import (
-    Any, Callable, Dict, NamedTuple, List, Optional, Sequence, Tuple, Union)
+    Any, Callable, Dict, NamedTuple, List, Optional, Sequence, Tuple)

 import numpy as np

@@ -353,11 +353,12 @@ def eval_sparse(
 ) -> Sequence[SparsifyValue]:
   env : Dict[core.Var, SparsifyValue] = {}

-  def read(var: core.Var) -> Union[Array, SparsifyValue]:
+  def read(var: core.Atom) -> SparsifyValue:
     # all literals are dense
     if isinstance(var, core.Literal):
       return spenv.dense(var.val)
     else:
+      assert isinstance(var, core.Var)
       return env[var]

   def write_buffer(var: core.Var, a: Array) -> None:

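The rewritten `read` widens its parameter to `core.Atom` (a Var or a Literal) and then narrows each branch with an isinstance check; the same pattern recurs in the next hunk's `read` and `aval`. A minimal sketch of the pattern with stand-in classes:

```python
from typing import Dict, Union

class Var: ...

class Literal:
  def __init__(self, val: float):
    self.val = val

Atom = Union[Var, Literal]  # stand-in for core.Atom

env: Dict[Var, float] = {}

def read(atom: Atom) -> float:
  if isinstance(atom, Literal):
    return atom.val  # literals carry their value inline
  else:
    # The assert narrows the Union to Var for the type checker (and guards
    # at runtime), so the lookup matches the Dict[Var, float] key type.
    assert isinstance(atom, Var)
    return env[atom]

v = Var()
env[v] = 2.0
print(read(Literal(1.5)), read(v))  # 1.5 2.0
```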
@@ -977,13 +977,14 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   Assumes that an MLIR context, location, and insertion point are set.
   """
   assert ctx.platform != "gpu"
-  def read(v: core.Var) -> Sequence[ir.Value]:
+  def read(v: core.Atom) -> Sequence[ir.Value]:
     if type(v) is core.Literal:
       return ir_constants(v.val, canonicalize_types=True)
     else:
+      assert isinstance(v, core.Var)
       return env[v]

-  def aval(v: core.Var) -> core.AbstractValue:
+  def aval(v: core.Atom) -> core.AbstractValue:
     if type(v) is core.Literal:
       return xla.abstractify(v.val)
     else:

@@ -31,7 +31,6 @@ from jax import linear_util as lu
 from jax._src import api_util
 from jax._src import dtypes
 from jax._src import profiler
 from jax._src.ad_util import Zero
 from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
 from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
                                 tree_leaves)

@@ -379,7 +378,7 @@ class JaxprTrace(Trace):
           for ax, a in zip(staged_out_axes, out_avals_mapped)]
     out_tracers = [JaxprTracer(self, PartialVal.unknown(a), None)
                    for a in out_avals]
-    eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),
+    eqn = new_eqn_recipe((*const_tracers, *env_tracers, *unknown_arg_tracers),  # type: ignore[arg-type]
                          out_tracers, primitive, staged_params,
                          jaxpr.effects,
                          source_info_util.current())

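A note on the `# type: ignore[arg-type]` form used in this and the surrounding hunks: scoping the ignore to an error code suppresses only that diagnostic, so any unrelated error on the same line still surfaces. A contrived but runnable illustration:

```python
from typing import List

def mean(xs: List[float]) -> float:
  return sum(xs) / len(xs)

ints: List[int] = [1, 2, 3]
# mypy rejects List[int] where List[float] is expected (invariance), even
# though this call is safe at runtime; the scoped ignore silences exactly
# that arg-type error and nothing else that might later appear here.
print(mean(ints))  # type: ignore[arg-type]
```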
@@ -945,7 +944,7 @@ def tracers_to_jaxpr(
   env_vars, env_vals = unzip2(env.items())
   const_vars, const_vals = unzip2(consts.items())
   effects = core.join_effects(*(eqn.effects for eqn in eqns))
-  jaxpr = Jaxpr(const_vars, [*env_vars, *map(get_atom, in_tracers)],
+  jaxpr = Jaxpr(const_vars, [*env_vars, *map(get_atom, in_tracers)],  # type: ignore[list-item]
                 map(get_atom, out_tracers), eqns, effects)
   config.jax_enable_checks and core.check_jaxpr(jaxpr)
   # del getvar  # needed to avoid cyclic-reference closure, apparently!

@@ -1166,7 +1165,8 @@ def _partial_eval_jaxpr_custom_cached(
       inputs = map(ensure_instantiated, inst_in, eqn.invars)
       staged_eqns.append(eqn.replace(invars=inputs))
       map(partial(write, False, True), eqn.outvars)
-  out_unknowns, out_inst = unsafe_map(list, unzip2(map(read, jaxpr.outvars)))
+  unzipped = unzip2(map(read, jaxpr.outvars))
+  out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1])
   assert all(type(v) is Var for v in residuals), residuals

   for x, inst, ensure_inst in zip(jaxpr.outvars, out_inst, ensure_out_inst):

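The final rewrite avoids mapping `list` over a heterogeneous pair, which collapses both sides to one element type for the checker; indexing the tuple returned by `unzip2` preserves each side's type. A sketch assuming `unzip2` transposes an iterable of pairs into a pair of tuples, as `jax._src.util.unzip2` does:

```python
from typing import Iterable, List, Tuple, TypeVar

A = TypeVar('A')
B = TypeVar('B')

def unzip2(pairs: Iterable[Tuple[A, B]]) -> Tuple[Tuple[A, ...], Tuple[B, ...]]:
  # Transpose an iterable of pairs into a pair of tuples.
  lst1: List[A] = []
  lst2: List[B] = []
  for a, b in pairs:
    lst1.append(a)
    lst2.append(b)
  return tuple(lst1), tuple(lst2)

pairs = [(True, 'x'), (False, 'y')]
unzipped = unzip2(pairs)
# Literal indexing keeps per-side types: unzipped[0] is Tuple[bool, ...] and
# unzipped[1] is Tuple[str, ...]; mapping list() over the pair would erase
# that distinction for the checker.
unknowns: List[bool] = list(unzipped[0])
names: List[str] = list(unzipped[1])
print(unknowns, names)  # [True, False] ['x', 'y']
```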