Add a variant of safe_map() that has no return value, named foreach().

This avoids a bunch of list bookkeeping in cases where we are iterating only for a side effect and do not care about the results.

I would have named this iter() to match OCaml's list.iter(), but unfortunately iter() is a Python builtin.

PiperOrigin-RevId: 736859418
This commit is contained in:
Peter Hawkins 2025-03-14 07:42:07 -07:00 committed by jax authors
parent 074216e07a
commit 8ab33669e2
20 changed files with 184 additions and 70 deletions

View File

@ -51,7 +51,7 @@ from jax._src.tree_util import tree_map
from jax._src.tree_util import tree_unflatten
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache, HashableWrapper)
unzip3, weakref_lru_cache, HashableWrapper, foreach)
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
@ -413,8 +413,8 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
def write_env(var: core.Var, val: Any):
env[var] = val
map(write_env, jaxpr.constvars, consts)
map(write_env, jaxpr.invars, in_args)
foreach(write_env, jaxpr.constvars, consts)
foreach(write_env, jaxpr.invars, in_args)
# interpreter loop
for eqn in jaxpr.eqns:
@ -427,7 +427,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value],
error, outvals = checkify_rule(error, enabled_errors,
*invals, **eqn.params)
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, outvals)
foreach(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)
core.clean_up_dead_vars(eqn, env, last_used)

View File

@ -50,7 +50,7 @@ from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
tuple_delete, cache,
HashableFunction, HashableWrapper, weakref_lru_cache,
partition_list, StrictABCMeta)
partition_list, StrictABCMeta, foreach)
import jax._src.pretty_printer as pp
from jax._src.named_sharding import NamedSharding
from jax._src.lib import jax_jit
@ -578,8 +578,8 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[
env[v] = val
env: dict[Var, Any] = {}
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
foreach(write, jaxpr.constvars, consts)
foreach(write, jaxpr.invars, args)
lu = last_used(jaxpr)
for eqn in jaxpr.eqns:
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
@ -589,7 +589,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[
traceback, name_stack=name_stack), eqn.ctx.manager:
ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
foreach(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
clean_up_dead_vars(eqn, env, lu)
@ -2837,7 +2837,7 @@ def _check_jaxpr(
# Check out_type matches the let-binders' annotation (after substitution).
out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars)
map(write, eqn.outvars, out_type)
foreach(write, eqn.outvars, out_type)
except JaxprTypeError as e:
ctx, settings = ctx_factory()
@ -2848,7 +2848,7 @@ def _check_jaxpr(
raise JaxprTypeError(msg, eqn_idx) from None
# TODO(mattjj): include output type annotation on jaxpr and check it here
map(read, jaxpr.outvars)
foreach(read, jaxpr.outvars)
def check_type(
ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]],

View File

@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, weakref_lru_cache,
partition_list, subs_list2)
partition_list, subs_list2, foreach)
zip = safe_zip
map = safe_map
@ -344,10 +344,10 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
primal_env[v] = val
primal_env: dict[Any, Any] = {}
map(write_primal, jaxpr.constvars, consts)
foreach(write_primal, jaxpr.constvars, consts)
# FIXME: invars can contain both primal and tangent values, and this line
# forces primal_in to contain UndefinedPrimals for tangent values!
map(write_primal, jaxpr.invars, primals_in)
foreach(write_primal, jaxpr.invars, primals_in)
# Start with a forward pass to evaluate any side-effect-free JaxprEqns that
# only operate on primals. This is required to support primitives with
@ -367,7 +367,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
traceback, name_stack=name_stack), eqn.ctx.manager:
ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params)
if eqn.primitive.multiple_results:
map(write_primal, eqn.outvars, ans)
foreach(write_primal, eqn.outvars, ans)
else:
write_primal(eqn.outvars[0], ans)
@ -375,7 +375,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
ctx = (source_info_util.transform_name_stack('transpose') if transform_stack
else contextlib.nullcontext())
with ctx:
map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
foreach(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
for eqn in lin_eqns[::-1]:
if eqn.primitive.ref_primitive:
if eqn.primitive is core.mutable_array_p:
@ -417,7 +417,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack,
raise e from None
cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
# FIXME: Some invars correspond to primals!
map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
foreach(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
cotangents_out = map(read_cotangent, jaxpr.invars)
return cotangents_out

View File

@ -53,6 +53,7 @@ from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (AUTO, NamedSharding,
modify_sdy_sharding_wrt_axis_types,
SdyArraySharding, SdyArrayShardingList)
from jax._src.util import foreach
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension, xla_extension_version
from jax._src.lib.mlir import dialects, ir, passmanager
@ -1941,8 +1942,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
foreach(write, jaxpr.constvars, consts)
foreach(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_nodes = map(read, eqn.invars)
@ -2009,7 +2010,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
f"{eqn}, got output {ans}") from e
assert len(ans) == len(eqn.outvars), (ans, eqn)
map(write, eqn.outvars, out_nodes)
foreach(write, eqn.outvars, out_nodes)
core.clean_up_dead_vars(eqn, env, last_used)
return tuple(read(v) for v in jaxpr.outvars), tokens
@ -2101,11 +2102,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
# If there is a single rule left just apply the rule, without conditionals.
if len(kept_rules) == 1:
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
map(
foreach(
lambda o: wrap_compute_type_in_place(ctx, o.owner),
filter(_is_not_block_argument, flatten_ir_values(output)),
)
map(
foreach(
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
flatten_ir_values(output),
)
@ -2146,11 +2147,11 @@ def lower_per_platform(ctx: LoweringRuleContext,
except TypeError as e:
raise ValueError("Output of translation rule must be iterable: "
f"{description}, got output {output}") from e
map(
foreach(
lambda o: wrap_compute_type_in_place(ctx, o.owner),
filter(_is_not_block_argument, out_nodes),
)
map(
foreach(
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
out_nodes,
)

View File

@ -47,7 +47,7 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple,
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache, subs_list,
HashableFunction)
HashableFunction, foreach)
map, unsafe_map = safe_map, map
@ -1085,8 +1085,8 @@ def _partial_eval_jaxpr_custom_cached(
newvar = core.gensym(suffix='_offload')
known_eqns, staged_eqns = [], []
map(write, in_unknowns, in_inst, jaxpr.invars)
map(partial(write, False, True), jaxpr.constvars)
foreach(write, in_unknowns, in_inst, jaxpr.invars)
foreach(partial(write, False, True), jaxpr.constvars)
for eqn in jaxpr.eqns:
unks_in, inst_in = unzip2(map(read, eqn.invars))
rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive)
@ -1098,18 +1098,18 @@ def _partial_eval_jaxpr_custom_cached(
residual_refs.add(r)
else:
residuals.add(r)
map(write, unks_out, inst_out, eqn.outvars)
foreach(write, unks_out, inst_out, eqn.outvars)
elif any(unks_in):
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(eqn.replace(invars=inputs))
map(partial(write, True, True), eqn.outvars)
foreach(partial(write, True, True), eqn.outvars)
else:
known_eqns.append(eqn)
# If it's an effectful primitive, we always to run and avoid staging it.
policy = ensure_enum(saveable(
eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params))
if has_effects(eqn.effects) or isinstance(policy, SaveableType):
map(partial(write, False, False), eqn.outvars)
foreach(partial(write, False, False), eqn.outvars)
elif isinstance(policy, Offloadable):
# TODO(slebedev): This is a legit error which requires a BUILD fix.
from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error
@ -1124,7 +1124,7 @@ def _partial_eval_jaxpr_custom_cached(
JaxprEqnContext(None, False))
known_eqns.append(offload_eqn)
# resvars are known and available in the backward jaxpr.
map(partial(write, False, True), resvars)
foreach(partial(write, False, True), resvars)
residuals.update(resvars)
reload_eqn = core.JaxprEqn(
resvars, eqn.outvars, device_put_p,
@ -1135,12 +1135,12 @@ def _partial_eval_jaxpr_custom_cached(
JaxprEqnContext(None, False))
staged_eqns.append(reload_eqn)
# outvars are known and available in the backward jaxpr.
map(partial(write, False, True), eqn.outvars)
foreach(partial(write, False, True), eqn.outvars)
else:
assert isinstance(policy, RecomputeType)
inputs = map(ensure_instantiated, inst_in, eqn.invars)
staged_eqns.append(eqn.replace(invars=inputs))
map(partial(write, False, True), eqn.outvars)
foreach(partial(write, False, True), eqn.outvars)
unzipped = unzip2(map(read, jaxpr.outvars))
out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1])
assert all(type(v) is Var for v in residuals), residuals
@ -1441,14 +1441,14 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
env[x] = read(x) or b
new_eqns = []
map(write, jaxpr.outvars, used_outputs)
foreach(write, jaxpr.outvars, used_outputs)
for eqn in jaxpr.eqns[::-1]:
used_outs = map(read, eqn.outvars)
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
used_ins, new_eqn = rule(used_outs, eqn)
if new_eqn is not None:
new_eqns.append(new_eqn)
map(write, eqn.invars, used_ins)
foreach(write, eqn.invars, used_ins)
used_inputs = map(read, jaxpr.invars)
used_inputs = map(op.or_, instantiate, used_inputs)
@ -2495,15 +2495,15 @@ def _eval_jaxpr_padded(
def write(v, val) -> None:
env[v] = val
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
foreach(write, jaxpr.constvars, consts)
foreach(write, jaxpr.invars, args)
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars]
out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars]
rule = padding_rules[eqn.primitive]
outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params)
map(write, eqn.outvars, outs)
foreach(write, eqn.outvars, outs)
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars)
@ -2580,7 +2580,7 @@ def inline_jaxpr_into_trace(
src_ = (src if not eqn.source_info.name_stack else
src.replace(name_stack=src.name_stack + eqn.source_info.name_stack))
trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_))
map(env.setdefault, eqn.outvars, outvars)
foreach(env.setdefault, eqn.outvars, outvars)
tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars],
[*consts, *arg_tracers]))

View File

@ -71,7 +71,8 @@ from jax._src.sharding_impls import (PmapSharding, NamedSharding,
PartitionSpec as P, canonicalize_sharding)
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis,
safe_map, safe_zip, split_list, weakref_lru_cache)
safe_map, safe_zip, split_list, weakref_lru_cache,
foreach)
_max = builtins.max
_min = builtins.min
@ -106,7 +107,7 @@ def _validate_shapes(shapes: Sequence[Shape]):
# pass dynamic shapes through unchecked
return
else:
map(_check_static_shape, shapes)
foreach(_check_static_shape, shapes)
def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]:
"""

View File

@ -37,6 +37,7 @@ from jax._src.pallas.fuser import block_spec
from jax._src.pallas.fuser.fusable import fusable_p
from jax._src.state import discharge as state_discharge
from jax._src.state import primitives as state_primitives
from jax._src.util import foreach
# TODO(sharadmv): Enable type checking.
# mypy: ignore-errors
@ -216,11 +217,11 @@ def physicalize_interp(
def write_env(var: core.Var, val: Any):
env[var] = val
map(write_env, jaxpr.constvars, consts)
foreach(write_env, jaxpr.constvars, consts)
assert len(jaxpr.invars) == len(
args
), f"Length mismatch: {jaxpr.invars} != {args}"
map(write_env, jaxpr.invars, args)
foreach(write_env, jaxpr.invars, args)
for eqn in jaxpr.eqns:
invals = list(map(read_env, eqn.invars))
@ -248,7 +249,7 @@ def physicalize_interp(
if eqn.primitive.multiple_results:
assert len(outvals) == len(eqn.outvars)
map(write_env, eqn.outvars, outvals)
foreach(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)

View File

@ -40,6 +40,7 @@ from jax._src.pallas import primitives
from jax._src.state import discharge as state_discharge
from jax._src import util
from jax._src.util import (
foreach,
safe_map,
safe_zip,
split_list,
@ -199,8 +200,8 @@ def eval_jaxpr_recursive(
env[v] = val
env: dict[jax_core.Var, Any] = {}
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
foreach(write, jaxpr.constvars, consts)
foreach(write, jaxpr.invars, args)
lu = jax_core.last_used(jaxpr)
for eqn in jaxpr.eqns:
in_vals = map(read, eqn.invars)
@ -216,7 +217,7 @@ def eval_jaxpr_recursive(
subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
ans = eqn.primitive.bind(*subfuns, *in_vals, **bind_params)
if eqn.primitive.multiple_results:
map(write, eqn.outvars, ans)
foreach(write, eqn.outvars, ans)
else:
write(eqn.outvars[0], ans)
jax_core.clean_up_dead_vars(eqn, env, lu)

View File

@ -65,6 +65,7 @@ from jax._src.state import primitives as state_primitives
from jax._src.state.types import RefBitcaster, RefReshaper
from jax._src.state.utils import dtype_bitwidth
from jax._src.typing import Array, DTypeLike
from jax._src.util import foreach
from jax._src.util import safe_map
from jax._src.util import safe_zip
from jax._src.util import split_list
@ -950,7 +951,7 @@ def jaxpr_subcomp(
for invar, bs in zip(jaxpr.invars, ctx.block_shapes):
block_shape_env[invar] = bs
map(write_env, jaxpr.invars, args)
foreach(write_env, jaxpr.invars, args)
initial_name_stack = [scope.name for scope in ctx.name_stack.stack]
current_name_stack: list[str] = []
@ -1011,7 +1012,7 @@ def jaxpr_subcomp(
f"{eqn.primitive.name}. "
"Please file an issue on https://github.com/jax-ml/jax/issues.")
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, ans)
foreach(write_env, eqn.outvars, ans)
else:
write_env(eqn.outvars[0], ans)

View File

@ -52,6 +52,7 @@ from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.state import types as state_types
from jax._src.state.types import RefReshaper
from jax._src.util import foreach
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import core as mgpu_core
from jax.experimental.mosaic.gpu import profiler as mgpu_profiler
@ -738,8 +739,8 @@ def lower_jaxpr_to_mosaic_gpu(
if val.type != mlir_dtype:
raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}")
map(write_env, jaxpr.constvars, consts)
map(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args)
foreach(write_env, jaxpr.constvars, consts)
foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args)
# TODO(justinfu): Handle transform scopes.
last_local_name_stack: list[str] = []
named_regions = []
@ -786,7 +787,7 @@ def lower_jaxpr_to_mosaic_gpu(
f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}"
) from e
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, outvals)
foreach(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)
while named_regions: # Drain the name stack.

View File

@ -33,6 +33,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.pallas import core as pallas_core
from jax._src.pallas.mosaic_gpu import core as gpu_core
from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives
from jax._src.util import foreach
from jax.experimental import pallas as pl
import jax.numpy as jnp
@ -247,7 +248,7 @@ def emit_pipeline(
it.islice(it.product(*map(range, grid)), max_concurrent_steps)
):
indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices))
map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs)
foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs)
# This is true if any of the outputs need to be transferred inside the loop.
copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs)

View File

@ -44,6 +44,7 @@ from jax._src.pallas import primitives
from jax._src.state import discharge as state_discharge
from jax._src.state import types as state_types
from jax._src.util import (
foreach,
safe_map,
safe_zip,
split_list,
@ -978,7 +979,7 @@ def pallas_call_checkify_oob_grid(error: checkify.Error,
for bm in grid_mapping.block_mappings]
# We perform a dynamic slice on the i/o blocks, which will be checked by
# checkify for OOB accesses.
map(hlo_interpreter._dynamic_slice, start_indices, block_shapes,
foreach(hlo_interpreter._dynamic_slice, start_indices, block_shapes,
[*input_args, *output_args], is_indexing_dim)
return (i + 1, hlo_interpreter._get_next_indices(grid, loop_idx))
def f(_):

View File

@ -53,6 +53,7 @@ from jax._src.pallas import utils as pallas_utils
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
from jax._src.util import foreach
from jax._src.util import merge_lists
from jax._src.util import partition_list
from jax._src.util import split_list
@ -380,7 +381,7 @@ def lower_jaxpr_to_triton_ir(
if block_info is not None:
block_info_env[invar] = block_info
map(write_env, jaxpr.invars, args)
foreach(write_env, jaxpr.invars, args)
for eqn in jaxpr.eqns:
invals = map(read_env, eqn.invars)
@ -410,7 +411,7 @@ def lower_jaxpr_to_triton_ir(
f"msg={e}"
) from e
if eqn.primitive.multiple_results:
map(write_env, eqn.outvars, outvals)
foreach(write_env, eqn.outvars, outvals)
else:
write_env(eqn.outvars[0], outvals)

View File

@ -45,6 +45,7 @@ from jax._src.state.types import (
from jax._src.state.utils import bitcast, hoist_consts_to_refs
from jax._src.typing import Array
from jax._src.util import (
foreach,
merge_lists,
partition_list,
safe_map,
@ -140,10 +141,10 @@ def _eval_jaxpr_discharge_state(
*args: Any):
env = Environment({})
map(env.write, jaxpr.constvars, consts)
foreach(env.write, jaxpr.constvars, consts)
# Here some args may correspond to `Ref` avals but they'll be treated like
# regular values in this interpreter.
map(env.write, jaxpr.invars, args)
foreach(env.write, jaxpr.invars, args)
refs_to_discharge = {id(v.aval) for v, d in zip(jaxpr.invars, should_discharge)
if d and isinstance(v.aval, AbstractRef)}
@ -195,7 +196,7 @@ def _eval_jaxpr_discharge_state(
ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars),
**bind_params)
if eqn.primitive.multiple_results:
map(env.write, eqn.outvars, ans)
foreach(env.write, eqn.outvars, ans)
else:
env.write(eqn.outvars[0], ans)
# By convention, we return the outputs of the jaxpr first and then the final

View File

@ -90,6 +90,31 @@ if TYPE_CHECKING:
else:
safe_map = jaxlib_utils.safe_map
if TYPE_CHECKING:
@overload
def foreach(f: Callable[[T1], Any], __arg1: Iterable[T1]) -> None: ...
@overload
def foreach(f: Callable[[T1, T2], Any], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> None: ...
@overload
def foreach(f: Callable[[T1, T2, T3], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> None: ...
@overload
def foreach(f: Callable[..., Any], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> None: ...
def foreach(f, *args):
safe_map(f, *args)
return None
else:
# TODO(phawkins): remove after jaxlib 0.5.2 is the minimum.
if hasattr(jaxlib_utils, 'foreach'):
foreach = jaxlib_utils.foreach
else:
foreach = safe_map
def unzip2(xys: Iterable[tuple[T1, T2]]
) -> tuple[tuple[T1, ...], tuple[T2, ...]]:
"""Unzip sequence of length-2 tuples into two tuples."""

View File

@ -28,6 +28,7 @@ from jax._src.api import make_jaxpr
from jax._src.interpreters.partial_eval import dce_jaxpr
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
from jax._src.util import foreach
from jax.experimental import shard_map
@ -165,12 +166,12 @@ def _roofline_interpreter(
jaxpr = jaxpr.jaxpr if isinstance(jaxpr, core.ClosedJaxpr) else jaxpr
make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x))
map(
foreach(
write,
jaxpr.constvars,
map(make_roofline_shape, jaxpr.constvars),
)
map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
foreach(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars))
last_used = core.last_used(jaxpr)
for eqn in jaxpr.eqns:
source_info = eqn.source_info.replace(
@ -210,7 +211,7 @@ def _roofline_interpreter(
**eqn.params,
)
map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars))
foreach(write, eqn.outvars, map(make_roofline_shape, eqn.outvars))
core.clean_up_dead_vars(eqn, env, last_used)
result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes())

View File

@ -58,7 +58,7 @@ from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,
as_hashable_function, memoize, partition_list,
split_list, subs_list2)
split_list, subs_list2, foreach)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
@ -648,15 +648,15 @@ def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType]
def write(v: core.Var, val: RepType) -> None:
env[v] = val
map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
map(write, jaxpr.invars, in_rep)
foreach(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars))
foreach(write, jaxpr.invars, in_rep)
last_used = core.last_used(jaxpr)
for e in jaxpr.eqns:
rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive))
out_rep = rule(mesh, *map(read, e.invars), **e.params)
if e.primitive.multiple_results:
out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep
map(write, e.outvars, out_rep)
foreach(write, e.outvars, out_rep)
else:
write(e.outvars[0], out_rep)
core.clean_up_dead_vars(e, env, last_used)

View File

@ -26,6 +26,7 @@ from jax import lax
from jax._src import core
from jax._src import util
from jax._src.util import foreach
import jax.experimental.slab.slab as sl
@ -107,11 +108,11 @@ def eval_djaxpr(jaxpr: core.Jaxpr, slab: sl.Slab, *args: jax.Array | sl.SlabView
def write(v, val):
env[v] = val
map(write, jaxpr.invars, args)
foreach(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
invals = map(read, eqn.invars)
slab, outvals = rules[eqn.primitive](slab, *invals, **eqn.params)
map(write, eqn.outvars, outvals)
foreach(write, eqn.outvars, outvals)
return slab, map(read, jaxpr.outvars)
rules: dict[core.Primitive, Callable] = {}

View File

@ -15,6 +15,8 @@ limitations under the License.
#include <Python.h>
#include <cstddef>
#include "nanobind/nanobind.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/inlined_vector.h"
@ -125,6 +127,79 @@ PyMethodDef safe_map_def = {
METH_FASTCALL,
};
// Similar to SafeMap, but ignores the return values of the function and returns
// None.
PyObject* ForEach(PyObject* self, PyObject* const* args, Py_ssize_t nargs) {
if (nargs < 2) {
PyErr_SetString(PyExc_TypeError, "foreach() requires at least 2 arguments");
return nullptr;
}
PyObject* fn = args[0];
absl::InlinedVector<nb::object, 4> iterators;
iterators.reserve(nargs - 1);
for (Py_ssize_t i = 1; i < nargs; ++i) {
PyObject* it = PyObject_GetIter(args[i]);
if (!it) return nullptr;
iterators.push_back(nb::steal<nb::object>(it));
}
// The arguments we will pass to fn. We allocate space for one more argument
// than we need at the start of the argument list so we can use
// PY_VECTORCALL_ARGUMENTS_OFFSET which may speed up the callee.
absl::InlinedVector<PyObject*, 4> values(nargs, nullptr);
while (true) {
absl::Cleanup values_cleanup = [&values]() {
for (PyObject* v : values) {
Py_XDECREF(v);
v = nullptr;
}
};
values[1] = PyIter_Next(iterators[0].ptr());
if (PyErr_Occurred()) return nullptr;
if (values[1]) {
for (size_t i = 1; i < iterators.size(); ++i) {
values[i + 1] = PyIter_Next(iterators[i].ptr());
if (PyErr_Occurred()) return nullptr;
if (!values[i + 1]) {
PyErr_Format(PyExc_ValueError,
"foreach() argument %u is shorter than argument 1",
i + 1);
return nullptr;
}
}
} else {
// No more elements should be left. Checks the other iterators are
// exhausted.
for (size_t i = 1; i < iterators.size(); ++i) {
values[i + 1] = PyIter_Next(iterators[i].ptr());
if (PyErr_Occurred()) return nullptr;
if (values[i + 1]) {
PyErr_Format(PyExc_ValueError,
"foreach() argument %u is longer than argument 1",
i + 1);
return nullptr;
}
}
Py_INCREF(Py_None);
return Py_None;
}
nb::object out = nb::steal<nb::object>(PyObject_Vectorcall(
fn, &values[1], (nargs - 1) | PY_VECTORCALL_ARGUMENTS_OFFSET,
/*kwnames=*/nullptr));
if (PyErr_Occurred()) {
return nullptr;
}
}
}
PyMethodDef foreach_def = {
"foreach", reinterpret_cast<PyCFunction>(ForEach), METH_FASTCALL,
"foreach() applies a function elementwise to one or more iterables, "
"ignoring the return values and returns None. The iterables must all have "
"the same lengths."};
// A variant of zip(...) that:
// a) returns a list instead of an iterator, and
// b) checks that the input iterables are of equal length.
@ -224,6 +299,8 @@ NB_MODULE(utils, m) {
nb::object module_name = m.attr("__name__");
m.attr("safe_map") = nb::steal<nb::object>(
PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr()));
m.attr("foreach") = nb::steal<nb::object>(
PyCFunction_NewEx(&foreach_def, /*self=*/nullptr, module_name.ptr()));
m.attr("safe_zip") = nb::steal<nb::object>(
PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr()));

View File

@ -30,7 +30,7 @@ jax.config.parse_flags_with_absl()
npr.seed(0)
from jax._src.util import unzip2, safe_zip, safe_map
from jax._src.util import foreach, unzip2, safe_zip, safe_map
map = safe_map
zip = safe_zip
@ -87,10 +87,10 @@ def eval_fun(fun, *args):
env[v] = x
env = {}
map(write, fun.in_vars, args)
foreach(write, fun.in_vars, args)
for in_vars, out_vars, f in fun.eqns:
out_vals = f(*map(read, in_vars))
map(write, out_vars, out_vals)
foreach(write, out_vars, out_vals)
return map(read, fun.out_vars)