mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add a variant of safe_map() that has no return value, named foreach().
This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results. I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin. PiperOrigin-RevId: 736859418
This commit is contained in:
parent
074216e07a
commit
8ab33669e2
@ -51,7 +51,7 @@ from jax._src.tree_util import tree_map
|
||||
from jax._src.tree_util import tree_unflatten
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
||||
unzip3, weakref_lru_cache, HashableWrapper)
|
||||
unzip3, weakref_lru_cache, HashableWrapper, foreach)
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
traceback_util.register_exclusion(__file__)
|
||||
@ -413,8 +413,8 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
|
||||
def write_env(var: core.Var, val: Any):
|
||||
env[var] = val
|
||||
|
||||
map(write_env, jaxpr.constvars, consts)
|
||||
map(write_env, jaxpr.invars, in_args)
|
||||
foreach(write_env, jaxpr.constvars, consts)
|
||||
foreach(write_env, jaxpr.invars, in_args)
|
||||
|
||||
# interpreter loop
|
||||
for eqn in jaxpr.eqns:
|
||||
@ -427,7 +427,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
|
||||
error, outvals = checkify_rule(error, enabled_errors,
|
||||
*invals, **eqn.params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_env, eqn.outvars, outvals)
|
||||
foreach(write_env, eqn.outvars, outvals)
|
||||
else:
|
||||
write_env(eqn.outvars[0], outvals)
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
|
@ -50,7 +50,7 @@ from jax._src import source_info_util
|
||||
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
|
||||
tuple_delete, cache,
|
||||
HashableFunction, HashableWrapper, weakref_lru_cache,
|
||||
partition_list, StrictABCMeta)
|
||||
partition_list, StrictABCMeta, foreach)
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.named_sharding import NamedSharding
|
||||
from jax._src.lib import jax_jit
|
||||
@ -578,8 +578,8 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[
|
||||
env[v] = val
|
||||
|
||||
env: dict[Var, Any] = {}
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
foreach(write, jaxpr.constvars, consts)
|
||||
foreach(write, jaxpr.invars, args)
|
||||
lu = last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
@ -589,7 +589,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[
|
||||
traceback, name_stack=name_stack), eqn.ctx.manager:
|
||||
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write, eqn.outvars, ans)
|
||||
foreach(write, eqn.outvars, ans)
|
||||
else:
|
||||
write(eqn.outvars[0], ans)
|
||||
clean_up_dead_vars(eqn, env, lu)
|
||||
@ -2837,7 +2837,7 @@ def _check_jaxpr(
|
||||
|
||||
# Check out_type matches the let-binders' annotation (after substitution).
|
||||
out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars)
|
||||
map(write, eqn.outvars, out_type)
|
||||
foreach(write, eqn.outvars, out_type)
|
||||
|
||||
except JaxprTypeError as e:
|
||||
ctx, settings = ctx_factory()
|
||||
@ -2848,7 +2848,7 @@ def _check_jaxpr(
|
||||
raise JaxprTypeError(msg, eqn_idx) from None
|
||||
|
||||
# TODO(mattjj): include output type annotation on jaxpr and check it here
|
||||
map(read, jaxpr.outvars)
|
||||
foreach(read, jaxpr.outvars)
|
||||
|
||||
def check_type(
|
||||
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],
|
||||
|
@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
|
||||
from jax._src.dtypes import dtype, float0
|
||||
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
|
||||
as_hashable_function, weakref_lru_cache,
|
||||
partition_list, subs_list2)
|
||||
partition_list, subs_list2, foreach)
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
@ -344,10 +344,10 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
|
||||
primal_env[v] = val
|
||||
|
||||
primal_env: dict[Any, Any] = {}
|
||||
map(write_primal, jaxpr.constvars, consts)
|
||||
foreach(write_primal, jaxpr.constvars, consts)
|
||||
# FIXME: invars can contain both primal and tangent values, and this line
|
||||
# forces primal_in to contain UndefinedPrimals for tangent values!
|
||||
map(write_primal, jaxpr.invars, primals_in)
|
||||
foreach(write_primal, jaxpr.invars, primals_in)
|
||||
|
||||
# Start with a forward pass to evaluate any side-effect-free JaxprEqns that
|
||||
# only operate on primals. This is required to support primitives with
|
||||
@ -367,7 +367,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
|
||||
traceback, name_stack=name_stack), eqn.ctx.manager:
|
||||
ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_primal, eqn.outvars, ans)
|
||||
foreach(write_primal, eqn.outvars, ans)
|
||||
else:
|
||||
write_primal(eqn.outvars[0], ans)
|
||||
|
||||
@ -375,7 +375,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
|
||||
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
|
||||
else contextlib.nullcontext())
|
||||
with ctx:
|
||||
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
|
||||
foreach(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
|
||||
for eqn in lin_eqns[::-1]:
|
||||
if eqn.primitive.ref_primitive:
|
||||
if eqn.primitive is core.mutable_array_p:
|
||||
@ -417,7 +417,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
|
||||
raise e from None
|
||||
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
|
||||
# FIXME: Some invars correspond to primals!
|
||||
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
|
||||
foreach(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
|
||||
|
||||
cotangents_out = map(read_cotangent, jaxpr.invars)
|
||||
return cotangents_out
|
||||
|
@ -53,6 +53,7 @@ from jax._src.sharding import Sharding as JSharding
|
||||
from jax._src.sharding_impls import (AUTO, NamedSharding,
|
||||
modify_sdy_sharding_wrt_axis_types,
|
||||
SdyArraySharding, SdyArrayShardingList)
|
||||
from jax._src.util import foreach
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension, xla_extension_version
|
||||
from jax._src.lib.mlir import dialects, ir, passmanager
|
||||
@ -1941,8 +1942,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
assert len(args) == len(jaxpr.invars), (jaxpr, args)
|
||||
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
|
||||
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
foreach(write, jaxpr.constvars, consts)
|
||||
foreach(write, jaxpr.invars, args)
|
||||
last_used = core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_nodes = map(read, eqn.invars)
|
||||
@ -2009,7 +2010,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
f"{eqn}, got output {ans}") from e
|
||||
|
||||
assert len(ans) == len(eqn.outvars), (ans, eqn)
|
||||
map(write, eqn.outvars, out_nodes)
|
||||
foreach(write, eqn.outvars, out_nodes)
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
return tuple(read(v) for v in jaxpr.outvars), tokens
|
||||
|
||||
@ -2101,11 +2102,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
|
||||
# If there is a single rule left just apply the rule, without conditionals.
|
||||
if len(kept_rules) == 1:
|
||||
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
|
||||
map(
|
||||
foreach(
|
||||
lambda o: wrap_compute_type_in_place(ctx, o.owner),
|
||||
filter(_is_not_block_argument, flatten_ir_values(output)),
|
||||
)
|
||||
map(
|
||||
foreach(
|
||||
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
|
||||
flatten_ir_values(output),
|
||||
)
|
||||
@ -2146,11 +2147,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
|
||||
except TypeError as e:
|
||||
raise ValueError("Output of translation rule must be iterable: "
|
||||
f"{description}, got output {output}") from e
|
||||
map(
|
||||
foreach(
|
||||
lambda o: wrap_compute_type_in_place(ctx, o.owner),
|
||||
filter(_is_not_block_argument, out_nodes),
|
||||
)
|
||||
map(
|
||||
foreach(
|
||||
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
|
||||
out_nodes,
|
||||
)
|
||||
|
@ -47,7 +47,7 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple,
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, OrderedSet,
|
||||
as_hashable_function, weakref_lru_cache, subs_list,
|
||||
HashableFunction)
|
||||
HashableFunction, foreach)
|
||||
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -1085,8 +1085,8 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
|
||||
newvar = core.gensym(suffix='_offload')
|
||||
known_eqns, staged_eqns = [], []
|
||||
map(write, in_unknowns, in_inst, jaxpr.invars)
|
||||
map(partial(write, False, True), jaxpr.constvars)
|
||||
foreach(write, in_unknowns, in_inst, jaxpr.invars)
|
||||
foreach(partial(write, False, True), jaxpr.constvars)
|
||||
for eqn in jaxpr.eqns:
|
||||
unks_in, inst_in = unzip2(map(read, eqn.invars))
|
||||
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
|
||||
@ -1098,18 +1098,18 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
residual_refs.add(r)
|
||||
else:
|
||||
residuals.add(r)
|
||||
map(write, unks_out, inst_out, eqn.outvars)
|
||||
foreach(write, unks_out, inst_out, eqn.outvars)
|
||||
elif any(unks_in):
|
||||
inputs = map(ensure_instantiated, inst_in, eqn.invars)
|
||||
staged_eqns.append(eqn.replace(invars=inputs))
|
||||
map(partial(write, True, True), eqn.outvars)
|
||||
foreach(partial(write, True, True), eqn.outvars)
|
||||
else:
|
||||
known_eqns.append(eqn)
|
||||
# If it's an effectful primitive, we always to run and avoid staging it.
|
||||
policy = ensure_enum(saveable(
|
||||
eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params))
|
||||
if has_effects(eqn.effects) or isinstance(policy, SaveableType):
|
||||
map(partial(write, False, False), eqn.outvars)
|
||||
foreach(partial(write, False, False), eqn.outvars)
|
||||
elif isinstance(policy, Offloadable):
|
||||
# TODO(slebedev): This is a legit error which requires a BUILD fix.
|
||||
from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error
|
||||
@ -1124,7 +1124,7 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
JaxprEqnContext(None, False))
|
||||
known_eqns.append(offload_eqn)
|
||||
# resvars are known and available in the backward jaxpr.
|
||||
map(partial(write, False, True), resvars)
|
||||
foreach(partial(write, False, True), resvars)
|
||||
residuals.update(resvars)
|
||||
reload_eqn = core.JaxprEqn(
|
||||
resvars, eqn.outvars, device_put_p,
|
||||
@ -1135,12 +1135,12 @@ def _partial_eval_jaxpr_custom_cached(
|
||||
JaxprEqnContext(None, False))
|
||||
staged_eqns.append(reload_eqn)
|
||||
# outvars are known and available in the backward jaxpr.
|
||||
map(partial(write, False, True), eqn.outvars)
|
||||
foreach(partial(write, False, True), eqn.outvars)
|
||||
else:
|
||||
assert isinstance(policy, RecomputeType)
|
||||
inputs = map(ensure_instantiated, inst_in, eqn.invars)
|
||||
staged_eqns.append(eqn.replace(invars=inputs))
|
||||
map(partial(write, False, True), eqn.outvars)
|
||||
foreach(partial(write, False, True), eqn.outvars)
|
||||
unzipped = unzip2(map(read, jaxpr.outvars))
|
||||
out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1])
|
||||
assert all(type(v) is Var for v in residuals), residuals
|
||||
@ -1441,14 +1441,14 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
|
||||
env[x] = read(x) or b
|
||||
|
||||
new_eqns = []
|
||||
map(write, jaxpr.outvars, used_outputs)
|
||||
foreach(write, jaxpr.outvars, used_outputs)
|
||||
for eqn in jaxpr.eqns[::-1]:
|
||||
used_outs = map(read, eqn.outvars)
|
||||
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
|
||||
used_ins, new_eqn = rule(used_outs, eqn)
|
||||
if new_eqn is not None:
|
||||
new_eqns.append(new_eqn)
|
||||
map(write, eqn.invars, used_ins)
|
||||
foreach(write, eqn.invars, used_ins)
|
||||
used_inputs = map(read, jaxpr.invars)
|
||||
used_inputs = map(op.or_, instantiate, used_inputs)
|
||||
|
||||
@ -2495,15 +2495,15 @@ def _eval_jaxpr_padded(
|
||||
def write(v, val) -> None:
|
||||
env[v] = val
|
||||
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
foreach(write, jaxpr.constvars, consts)
|
||||
foreach(write, jaxpr.invars, args)
|
||||
last_used = core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
|
||||
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
|
||||
rule = padding_rules[eqn.primitive]
|
||||
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
|
||||
map(write, eqn.outvars, outs)
|
||||
foreach(write, eqn.outvars, outs)
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
return map(read, jaxpr.outvars)
|
||||
|
||||
@ -2580,7 +2580,7 @@ def inline_jaxpr_into_trace(
|
||||
src_ = (src if not eqn.source_info.name_stack else
|
||||
src.replace(name_stack=src.name_stack + eqn.source_info.name_stack))
|
||||
trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_))
|
||||
map(env.setdefault, eqn.outvars, outvars)
|
||||
foreach(env.setdefault, eqn.outvars, outvars)
|
||||
|
||||
tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars],
|
||||
[*consts, *arg_tracers]))
|
||||
|
@ -71,7 +71,8 @@ from jax._src.sharding_impls import (PmapSharding, NamedSharding,
|
||||
PartitionSpec as P, canonicalize_sharding)
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
|
||||
from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
|
||||
safe_map, safe_zip, split_list, weakref_lru_cache)
|
||||
safe_map, safe_zip, split_list, weakref_lru_cache,
|
||||
foreach)
|
||||
|
||||
_max = builtins.max
|
||||
_min = builtins.min
|
||||
@ -106,7 +107,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
|
||||
# pass dynamic shapes through unchecked
|
||||
return
|
||||
else:
|
||||
map(_check_static_shape, shapes)
|
||||
foreach(_check_static_shape, shapes)
|
||||
|
||||
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
|
||||
"""
|
||||
|
@ -37,6 +37,7 @@ from jax._src.pallas.fuser import block_spec
|
||||
from jax._src.pallas.fuser.fusable import fusable_p
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.state import primitives as state_primitives
|
||||
from jax._src.util import foreach
|
||||
|
||||
# TODO(sharadmv): Enable type checking.
|
||||
# mypy: ignore-errors
|
||||
@ -216,11 +217,11 @@ def physicalize_interp(
|
||||
def write_env(var: core.Var, val: Any):
|
||||
env[var] = val
|
||||
|
||||
map(write_env, jaxpr.constvars, consts)
|
||||
foreach(write_env, jaxpr.constvars, consts)
|
||||
assert len(jaxpr.invars) == len(
|
||||
args
|
||||
), f"Length mismatch: {jaxpr.invars} != {args}"
|
||||
map(write_env, jaxpr.invars, args)
|
||||
foreach(write_env, jaxpr.invars, args)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
invals = list(map(read_env, eqn.invars))
|
||||
@ -248,7 +249,7 @@ def physicalize_interp(
|
||||
|
||||
if eqn.primitive.multiple_results:
|
||||
assert len(outvals) == len(eqn.outvars)
|
||||
map(write_env, eqn.outvars, outvals)
|
||||
foreach(write_env, eqn.outvars, outvals)
|
||||
else:
|
||||
write_env(eqn.outvars[0], outvals)
|
||||
|
||||
|
@ -40,6 +40,7 @@ from jax._src.pallas import primitives
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src import util
|
||||
from jax._src.util import (
|
||||
foreach,
|
||||
safe_map,
|
||||
safe_zip,
|
||||
split_list,
|
||||
@ -199,8 +200,8 @@ def eval_jaxpr_recursive(
|
||||
env[v] = val
|
||||
|
||||
env: dict[jax_core.Var, Any] = {}
|
||||
map(write, jaxpr.constvars, consts)
|
||||
map(write, jaxpr.invars, args)
|
||||
foreach(write, jaxpr.constvars, consts)
|
||||
foreach(write, jaxpr.invars, args)
|
||||
lu = jax_core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
in_vals = map(read, eqn.invars)
|
||||
@ -216,7 +217,7 @@ def eval_jaxpr_recursive(
|
||||
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
|
||||
ans = eqn.primitive.bind(*subfuns, *in_vals, **bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write, eqn.outvars, ans)
|
||||
foreach(write, eqn.outvars, ans)
|
||||
else:
|
||||
write(eqn.outvars[0], ans)
|
||||
jax_core.clean_up_dead_vars(eqn, env, lu)
|
||||
|
@ -65,6 +65,7 @@ from jax._src.state import primitives as state_primitives
|
||||
from jax._src.state.types import RefBitcaster, RefReshaper
|
||||
from jax._src.state.utils import dtype_bitwidth
|
||||
from jax._src.typing import Array, DTypeLike
|
||||
from jax._src.util import foreach
|
||||
from jax._src.util import safe_map
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import split_list
|
||||
@ -950,7 +951,7 @@ def jaxpr_subcomp(
|
||||
|
||||
for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
|
||||
block_shape_env[invar] = bs
|
||||
map(write_env, jaxpr.invars, args)
|
||||
foreach(write_env, jaxpr.invars, args)
|
||||
|
||||
initial_name_stack = [scope.name for scope in ctx.name_stack.stack]
|
||||
current_name_stack: list[str] = []
|
||||
@ -1011,7 +1012,7 @@ def jaxpr_subcomp(
|
||||
f"{eqn.primitive.name}. "
|
||||
"Please file an issue on https://github.com/jax-ml/jax/issues.")
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_env, eqn.outvars, ans)
|
||||
foreach(write_env, eqn.outvars, ans)
|
||||
else:
|
||||
write_env(eqn.outvars[0], ans)
|
||||
|
||||
|
@ -52,6 +52,7 @@ from jax._src.state import indexing
|
||||
from jax._src.state import primitives as sp
|
||||
from jax._src.state import types as state_types
|
||||
from jax._src.state.types import RefReshaper
|
||||
from jax._src.util import foreach
|
||||
import jax.experimental.mosaic.gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import core as mgpu_core
|
||||
from jax.experimental.mosaic.gpu import profiler as mgpu_profiler
|
||||
@ -738,8 +739,8 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
if val.type != mlir_dtype:
|
||||
raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}")
|
||||
|
||||
map(write_env, jaxpr.constvars, consts)
|
||||
map(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args)
|
||||
foreach(write_env, jaxpr.constvars, consts)
|
||||
foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args)
|
||||
# TODO(justinfu): Handle transform scopes.
|
||||
last_local_name_stack: list[str] = []
|
||||
named_regions = []
|
||||
@ -786,7 +787,7 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}"
|
||||
) from e
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_env, eqn.outvars, outvals)
|
||||
foreach(write_env, eqn.outvars, outvals)
|
||||
else:
|
||||
write_env(eqn.outvars[0], outvals)
|
||||
while named_regions: # Drain the name stack.
|
||||
|
@ -33,6 +33,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.pallas import core as pallas_core
|
||||
from jax._src.pallas.mosaic_gpu import core as gpu_core
|
||||
from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives
|
||||
from jax._src.util import foreach
|
||||
from jax.experimental import pallas as pl
|
||||
import jax.numpy as jnp
|
||||
|
||||
@ -247,7 +248,7 @@ def emit_pipeline(
|
||||
it.islice(it.product(*map(range, grid)), max_concurrent_steps)
|
||||
):
|
||||
indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices))
|
||||
map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs)
|
||||
foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs)
|
||||
|
||||
# This is true if any of the outputs need to be transferred inside the loop.
|
||||
copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs)
|
||||
|
@ -44,6 +44,7 @@ from jax._src.pallas import primitives
|
||||
from jax._src.state import discharge as state_discharge
|
||||
from jax._src.state import types as state_types
|
||||
from jax._src.util import (
|
||||
foreach,
|
||||
safe_map,
|
||||
safe_zip,
|
||||
split_list,
|
||||
@ -978,7 +979,7 @@ def pallas_call_checkify_oob_grid(error: checkify.Error,
|
||||
for bm in grid_mapping.block_mappings]
|
||||
# We perform a dynamic slice on the i/o blocks, which will be checked by
|
||||
# checkify for OOB accesses.
|
||||
map(hlo_interpreter._dynamic_slice, start_indices, block_shapes,
|
||||
foreach(hlo_interpreter._dynamic_slice, start_indices, block_shapes,
|
||||
[*input_args, *output_args], is_indexing_dim)
|
||||
return (i + 1, hlo_interpreter._get_next_indices(grid, loop_idx))
|
||||
def f(_):
|
||||
|
@ -53,6 +53,7 @@ from jax._src.pallas import utils as pallas_utils
|
||||
from jax._src.state import discharge
|
||||
from jax._src.state import indexing
|
||||
from jax._src.state import primitives as sp
|
||||
from jax._src.util import foreach
|
||||
from jax._src.util import merge_lists
|
||||
from jax._src.util import partition_list
|
||||
from jax._src.util import split_list
|
||||
@ -380,7 +381,7 @@ def lower_jaxpr_to_triton_ir(
|
||||
if block_info is not None:
|
||||
block_info_env[invar] = block_info
|
||||
|
||||
map(write_env, jaxpr.invars, args)
|
||||
foreach(write_env, jaxpr.invars, args)
|
||||
|
||||
for eqn in jaxpr.eqns:
|
||||
invals = map(read_env, eqn.invars)
|
||||
@ -410,7 +411,7 @@ def lower_jaxpr_to_triton_ir(
|
||||
f"msg={e}"
|
||||
) from e
|
||||
if eqn.primitive.multiple_results:
|
||||
map(write_env, eqn.outvars, outvals)
|
||||
foreach(write_env, eqn.outvars, outvals)
|
||||
else:
|
||||
write_env(eqn.outvars[0], outvals)
|
||||
|
||||
|
@ -45,6 +45,7 @@ from jax._src.state.types import (
|
||||
from jax._src.state.utils import bitcast, hoist_consts_to_refs
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (
|
||||
foreach,
|
||||
merge_lists,
|
||||
partition_list,
|
||||
safe_map,
|
||||
@ -140,10 +141,10 @@ def _eval_jaxpr_discharge_state(
|
||||
*args: Any):
|
||||
env = Environment({})
|
||||
|
||||
map(env.write, jaxpr.constvars, consts)
|
||||
foreach(env.write, jaxpr.constvars, consts)
|
||||
# Here some args may correspond to `Ref` avals but they'll be treated like
|
||||
# regular values in this interpreter.
|
||||
map(env.write, jaxpr.invars, args)
|
||||
foreach(env.write, jaxpr.invars, args)
|
||||
|
||||
refs_to_discharge = {id(v.aval) for v, d in zip(jaxpr.invars, should_discharge)
|
||||
if d and isinstance(v.aval, AbstractRef)}
|
||||
@ -195,7 +196,7 @@ def _eval_jaxpr_discharge_state(
|
||||
ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars),
|
||||
**bind_params)
|
||||
if eqn.primitive.multiple_results:
|
||||
map(env.write, eqn.outvars, ans)
|
||||
foreach(env.write, eqn.outvars, ans)
|
||||
else:
|
||||
env.write(eqn.outvars[0], ans)
|
||||
# By convention, we return the outputs of the jaxpr first and then the final
|
||||
|
@ -90,6 +90,31 @@ if TYPE_CHECKING:
|
||||
else:
|
||||
safe_map = jaxlib_utils.safe_map
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@overload
|
||||
def foreach(f: Callable[[T1], Any], __arg1: Iterable[T1]) -> None: ...
|
||||
|
||||
@overload
|
||||
def foreach(f: Callable[[T1, T2], Any], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> None: ...
|
||||
|
||||
@overload
|
||||
def foreach(f: Callable[[T1, T2, T3], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> None: ...
|
||||
|
||||
@overload
|
||||
def foreach(f: Callable[..., Any], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> None: ...
|
||||
|
||||
def foreach(f, *args):
|
||||
safe_map(f, *args)
|
||||
return None
|
||||
|
||||
else:
|
||||
# TODO(phawkins): remove after jaxlib 0.5.2 is the minimum.
|
||||
if hasattr(jaxlib_utils, 'foreach'):
|
||||
foreach = jaxlib_utils.foreach
|
||||
else:
|
||||
foreach = safe_map
|
||||
|
||||
|
||||
def unzip2(xys: Iterable[tuple[T1, T2]]
|
||||
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
|
||||
"""Unzip sequence of length-2 tuples into two tuples."""
|
||||
|
@ -28,6 +28,7 @@ from jax._src.api import make_jaxpr
|
||||
from jax._src.interpreters.partial_eval import dce_jaxpr
|
||||
from jax._src.mesh import AbstractMesh, Mesh
|
||||
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
|
||||
from jax._src.util import foreach
|
||||
from jax.experimental import shard_map
|
||||
|
||||
|
||||
@ -165,12 +166,12 @@ def _roofline_interpreter(
|
||||
|
||||
jaxpr = jaxpr.jaxpr if isinstance(jaxpr, core.ClosedJaxpr) else jaxpr
|
||||
make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x))
|
||||
map(
|
||||
foreach(
|
||||
write,
|
||||
jaxpr.constvars,
|
||||
map(make_roofline_shape, jaxpr.constvars),
|
||||
)
|
||||
map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
|
||||
foreach(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
|
||||
last_used = core.last_used(jaxpr)
|
||||
for eqn in jaxpr.eqns:
|
||||
source_info = eqn.source_info.replace(
|
||||
@ -210,7 +211,7 @@ def _roofline_interpreter(
|
||||
**eqn.params,
|
||||
)
|
||||
|
||||
map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars))
|
||||
foreach(write, eqn.outvars, map(make_roofline_shape, eqn.outvars))
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes())
|
||||
|
||||
|
@ -58,7 +58,7 @@ from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
|
||||
as_hashable_function, memoize, partition_list,
|
||||
split_list, subs_list2)
|
||||
split_list, subs_list2, foreach)
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
@ -648,15 +648,15 @@ def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType]
|
||||
def write(v: core.Var, val: RepType) -> None:
|
||||
env[v] = val
|
||||
|
||||
map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
|
||||
map(write, jaxpr.invars, in_rep)
|
||||
foreach(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
|
||||
foreach(write, jaxpr.invars, in_rep)
|
||||
last_used = core.last_used(jaxpr)
|
||||
for e in jaxpr.eqns:
|
||||
rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive))
|
||||
out_rep = rule(mesh, *map(read, e.invars), **e.params)
|
||||
if e.primitive.multiple_results:
|
||||
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
|
||||
map(write, e.outvars, out_rep)
|
||||
foreach(write, e.outvars, out_rep)
|
||||
else:
|
||||
write(e.outvars[0], out_rep)
|
||||
core.clean_up_dead_vars(e, env, last_used)
|
||||
|
@ -26,6 +26,7 @@ from jax import lax
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import util
|
||||
from jax._src.util import foreach
|
||||
|
||||
import jax.experimental.slab.slab as sl
|
||||
|
||||
@ -107,11 +108,11 @@ def eval_djaxpr(jaxpr: core.Jaxpr, slab: sl.Slab, *args: jax.Array | sl.SlabView
|
||||
def write(v, val):
|
||||
env[v] = val
|
||||
|
||||
map(write, jaxpr.invars, args)
|
||||
foreach(write, jaxpr.invars, args)
|
||||
for eqn in jaxpr.eqns:
|
||||
invals = map(read, eqn.invars)
|
||||
slab, outvals = rules[eqn.primitive](slab, *invals, **eqn.params)
|
||||
map(write, eqn.outvars, outvals)
|
||||
foreach(write, eqn.outvars, outvals)
|
||||
return slab, map(read, jaxpr.outvars)
|
||||
|
||||
rules: dict[core.Primitive, Callable] = {}
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "absl/cleanup/cleanup.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
@ -125,6 +127,79 @@ PyMethodDef safe_map_def = {
|
||||
METH_FASTCALL,
|
||||
};
|
||||
|
||||
// Similar to SafeMap, but ignores the return values of the function and returns
|
||||
// None.
|
||||
PyObject* ForEach(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
|
||||
if (nargs < 2) {
|
||||
PyErr_SetString(PyExc_TypeError, "foreach() requires at least 2 arguments");
|
||||
return nullptr;
|
||||
}
|
||||
PyObject* fn = args[0];
|
||||
absl::InlinedVector<nb::object, 4> iterators;
|
||||
iterators.reserve(nargs - 1);
|
||||
for (Py_ssize_t i = 1; i < nargs; ++i) {
|
||||
PyObject* it = PyObject_GetIter(args[i]);
|
||||
if (!it) return nullptr;
|
||||
iterators.push_back(nb::steal<nb::object>(it));
|
||||
}
|
||||
|
||||
// The arguments we will pass to fn. We allocate space for one more argument
|
||||
// than we need at the start of the argument list so we can use
|
||||
// PY_VECTORCALL_ARGUMENTS_OFFSET which may speed up the callee.
|
||||
absl::InlinedVector<PyObject*, 4> values(nargs, nullptr);
|
||||
while (true) {
|
||||
absl::Cleanup values_cleanup = [&values]() {
|
||||
for (PyObject* v : values) {
|
||||
Py_XDECREF(v);
|
||||
v = nullptr;
|
||||
}
|
||||
};
|
||||
values[1] = PyIter_Next(iterators[0].ptr());
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
|
||||
if (values[1]) {
|
||||
for (size_t i = 1; i < iterators.size(); ++i) {
|
||||
values[i + 1] = PyIter_Next(iterators[i].ptr());
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
if (!values[i + 1]) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"foreach() argument %u is shorter than argument 1",
|
||||
i + 1);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No more elements should be left. Checks the other iterators are
|
||||
// exhausted.
|
||||
for (size_t i = 1; i < iterators.size(); ++i) {
|
||||
values[i + 1] = PyIter_Next(iterators[i].ptr());
|
||||
if (PyErr_Occurred()) return nullptr;
|
||||
if (values[i + 1]) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"foreach() argument %u is longer than argument 1",
|
||||
i + 1);
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
Py_INCREF(Py_None);
|
||||
return Py_None;
|
||||
}
|
||||
|
||||
nb::object out = nb::steal<nb::object>(PyObject_Vectorcall(
|
||||
fn, &values[1], (nargs - 1) | PY_VECTORCALL_ARGUMENTS_OFFSET,
|
||||
/*kwnames=*/nullptr));
|
||||
if (PyErr_Occurred()) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
PyMethodDef foreach_def = {
|
||||
"foreach", reinterpret_cast<PyCFunction>(ForEach), METH_FASTCALL,
|
||||
"foreach() applies a function elementwise to one or more iterables, "
|
||||
"ignoring the return values and returns None. The iterables must all have "
|
||||
"the same lengths."};
|
||||
|
||||
// A variant of zip(...) that:
|
||||
// a) returns a list instead of an iterator, and
|
||||
// b) checks that the input iterables are of equal length.
|
||||
@ -224,6 +299,8 @@ NB_MODULE(utils, m) {
|
||||
nb::object module_name = m.attr("__name__");
|
||||
m.attr("safe_map") = nb::steal<nb::object>(
|
||||
PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr()));
|
||||
m.attr("foreach") = nb::steal<nb::object>(
|
||||
PyCFunction_NewEx(&foreach_def, /*self=*/nullptr, module_name.ptr()));
|
||||
m.attr("safe_zip") = nb::steal<nb::object>(
|
||||
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));
|
||||
|
||||
|
@ -30,7 +30,7 @@ jax.config.parse_flags_with_absl()
|
||||
|
||||
npr.seed(0)
|
||||
|
||||
from jax._src.util import unzip2, safe_zip, safe_map
|
||||
from jax._src.util import foreach, unzip2, safe_zip, safe_map
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
@ -87,10 +87,10 @@ def eval_fun(fun, *args):
|
||||
env[v] = x
|
||||
|
||||
env = {}
|
||||
map(write, fun.in_vars, args)
|
||||
foreach(write, fun.in_vars, args)
|
||||
for in_vars, out_vars, f in fun.eqns:
|
||||
out_vals = f(*map(read, in_vars))
|
||||
map(write, out_vars, out_vals)
|
||||
foreach(write, out_vars, out_vals)
|
||||
|
||||
return map(read, fun.out_vars)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user