diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 8a797e1f3..1ec8ad50b 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -51,7 +51,7 @@ from jax._src.tree_util import tree_map from jax._src.tree_util import tree_unflatten from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, - unzip3, weakref_lru_cache, HashableWrapper) + unzip3, weakref_lru_cache, HashableWrapper, foreach) source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) @@ -413,8 +413,8 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value], def write_env(var: core.Var, val: Any): env[var] = val - map(write_env, jaxpr.constvars, consts) - map(write_env, jaxpr.invars, in_args) + foreach(write_env, jaxpr.constvars, consts) + foreach(write_env, jaxpr.invars, in_args) # interpreter loop for eqn in jaxpr.eqns: @@ -427,7 +427,7 @@ def checkify_jaxpr_flat(jaxpr: core.Jaxpr, consts: Sequence[core.Value], error, outvals = checkify_rule(error, enabled_errors, *invals, **eqn.params) if eqn.primitive.multiple_results: - map(write_env, eqn.outvars, outvals) + foreach(write_env, eqn.outvars, outvals) else: write_env(eqn.outvars[0], outvals) core.clean_up_dead_vars(eqn, env, last_used) diff --git a/jax/_src/core.py b/jax/_src/core.py index 20fe05656..167244064 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -50,7 +50,7 @@ from jax._src import source_info_util from jax._src.util import (safe_zip, safe_map, curry, tuple_insert, tuple_delete, cache, HashableFunction, HashableWrapper, weakref_lru_cache, - partition_list, StrictABCMeta) + partition_list, StrictABCMeta, foreach) import jax._src.pretty_printer as pp from jax._src.named_sharding import NamedSharding from jax._src.lib import jax_jit @@ -578,8 +578,8 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[ env[v] = val env: dict[Var, Any] = {} - map(write, jaxpr.constvars, consts) - 
map(write, jaxpr.invars, args) + foreach(write, jaxpr.constvars, consts) + foreach(write, jaxpr.invars, args) lu = last_used(jaxpr) for eqn in jaxpr.eqns: subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) @@ -589,7 +589,7 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[ traceback, name_stack=name_stack), eqn.ctx.manager: ans = eqn.primitive.bind(*subfuns, *map(read, eqn.invars), **bind_params) if eqn.primitive.multiple_results: - map(write, eqn.outvars, ans) + foreach(write, eqn.outvars, ans) else: write(eqn.outvars[0], ans) clean_up_dead_vars(eqn, env, lu) @@ -2837,7 +2837,7 @@ def _check_jaxpr( # Check out_type matches the let-binders' annotation (after substitution). out_type = substitute_vars_in_output_ty(out_type, eqn.invars, eqn.outvars) - map(write, eqn.outvars, out_type) + foreach(write, eqn.outvars, out_type) except JaxprTypeError as e: ctx, settings = ctx_factory() @@ -2848,7 +2848,7 @@ def _check_jaxpr( raise JaxprTypeError(msg, eqn_idx) from None # TODO(mattjj): include output type annotation on jaxpr and check it here - map(read, jaxpr.outvars) + foreach(read, jaxpr.outvars) def check_type( ctx_factory: Callable[[], tuple[JaxprPpContext, JaxprPpSettings]], diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 8d9ad53df..327528b69 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -39,7 +39,7 @@ from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal) from jax._src.dtypes import dtype, float0 from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name, as_hashable_function, weakref_lru_cache, - partition_list, subs_list2) + partition_list, subs_list2, foreach) zip = safe_zip map = safe_map @@ -344,10 +344,10 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack, primal_env[v] = val primal_env: dict[Any, Any] = {} - map(write_primal, jaxpr.constvars, consts) + foreach(write_primal, jaxpr.constvars, consts) # FIXME: 
invars can contain both primal and tangent values, and this line # forces primal_in to contain UndefinedPrimals for tangent values! - map(write_primal, jaxpr.invars, primals_in) + foreach(write_primal, jaxpr.invars, primals_in) # Start with a forward pass to evaluate any side-effect-free JaxprEqns that # only operate on primals. This is required to support primitives with @@ -367,7 +367,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack, traceback, name_stack=name_stack), eqn.ctx.manager: ans = eqn.primitive.bind(*subfuns, *map(read_primal, eqn.invars), **bind_params) if eqn.primitive.multiple_results: - map(write_primal, eqn.outvars, ans) + foreach(write_primal, eqn.outvars, ans) else: write_primal(eqn.outvars[0], ans) @@ -375,7 +375,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack, ctx = (source_info_util.transform_name_stack('transpose') if transform_stack else contextlib.nullcontext()) with ctx: - map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) + foreach(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in) for eqn in lin_eqns[::-1]: if eqn.primitive.ref_primitive: if eqn.primitive is core.mutable_array_p: @@ -417,7 +417,7 @@ def backward_pass(jaxpr: core.Jaxpr, transform_stack, raise e from None cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out # FIXME: Some invars correspond to primals! 
- map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) + foreach(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out) cotangents_out = map(read_cotangent, jaxpr.invars) return cotangents_out diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 695044252..36dcdca95 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -53,6 +53,7 @@ from jax._src.sharding import Sharding as JSharding from jax._src.sharding_impls import (AUTO, NamedSharding, modify_sdy_sharding_wrt_axis_types, SdyArraySharding, SdyArrayShardingList) +from jax._src.util import foreach from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension, xla_extension_version from jax._src.lib.mlir import dialects, ir, passmanager @@ -1941,8 +1942,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, assert len(args) == len(jaxpr.invars), (jaxpr, args) assert len(consts) == len(jaxpr.constvars), (jaxpr, consts) assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values) - map(write, jaxpr.constvars, consts) - map(write, jaxpr.invars, args) + foreach(write, jaxpr.constvars, consts) + foreach(write, jaxpr.invars, args) last_used = core.last_used(jaxpr) for eqn in jaxpr.eqns: in_nodes = map(read, eqn.invars) @@ -2009,7 +2010,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, f"{eqn}, got output {ans}") from e assert len(ans) == len(eqn.outvars), (ans, eqn) - map(write, eqn.outvars, out_nodes) + foreach(write, eqn.outvars, out_nodes) core.clean_up_dead_vars(eqn, env, last_used) return tuple(read(v) for v in jaxpr.outvars), tokens @@ -2101,11 +2102,11 @@ def lower_per_platform(ctx: LoweringRuleContext, # If there is a single rule left just apply the rule, without conditionals. 
if len(kept_rules) == 1: output = kept_rules[0](ctx, *rule_args, **rule_kwargs) - map( + foreach( lambda o: wrap_compute_type_in_place(ctx, o.owner), filter(_is_not_block_argument, flatten_ir_values(output)), ) - map( + foreach( lambda o: wrap_xla_metadata_in_place(ctx, o.owner), flatten_ir_values(output), ) @@ -2146,11 +2147,11 @@ def lower_per_platform(ctx: LoweringRuleContext, except TypeError as e: raise ValueError("Output of translation rule must be iterable: " f"{description}, got output {output}") from e - map( + foreach( lambda o: wrap_compute_type_in_place(ctx, o.owner), filter(_is_not_block_argument, out_nodes), ) - map( + foreach( lambda o: wrap_xla_metadata_in_place(ctx, o.owner), out_nodes, ) diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 526cc3bf0..16e090dd4 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -47,7 +47,7 @@ from jax._src.tree_util import (PyTreeDef, treedef_tuple, from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list, merge_lists, partition_list, OrderedSet, as_hashable_function, weakref_lru_cache, subs_list, - HashableFunction) + HashableFunction, foreach) map, unsafe_map = safe_map, map @@ -1085,8 +1085,8 @@ def _partial_eval_jaxpr_custom_cached( newvar = core.gensym(suffix='_offload') known_eqns, staged_eqns = [], [] - map(write, in_unknowns, in_inst, jaxpr.invars) - map(partial(write, False, True), jaxpr.constvars) + foreach(write, in_unknowns, in_inst, jaxpr.invars) + foreach(partial(write, False, True), jaxpr.constvars) for eqn in jaxpr.eqns: unks_in, inst_in = unzip2(map(read, eqn.invars)) rule = partial_eval_jaxpr_custom_rules.get(eqn.primitive) @@ -1098,18 +1098,18 @@ def _partial_eval_jaxpr_custom_cached( residual_refs.add(r) else: residuals.add(r) - map(write, unks_out, inst_out, eqn.outvars) + foreach(write, unks_out, inst_out, eqn.outvars) elif any(unks_in): inputs = map(ensure_instantiated, inst_in, 
eqn.invars) staged_eqns.append(eqn.replace(invars=inputs)) - map(partial(write, True, True), eqn.outvars) + foreach(partial(write, True, True), eqn.outvars) else: known_eqns.append(eqn) # If it's an effectful primitive, we always to run and avoid staging it. policy = ensure_enum(saveable( eqn.primitive, *[x.aval for x in eqn.invars], **eqn.params)) if has_effects(eqn.effects) or isinstance(policy, SaveableType): - map(partial(write, False, False), eqn.outvars) + foreach(partial(write, False, False), eqn.outvars) elif isinstance(policy, Offloadable): # TODO(slebedev): This is a legit error which requires a BUILD fix. from jax._src.dispatch import device_put_p, TransferToMemoryKind, CopySemantics # pytype: disable=import-error @@ -1124,7 +1124,7 @@ def _partial_eval_jaxpr_custom_cached( JaxprEqnContext(None, False)) known_eqns.append(offload_eqn) # resvars are known and available in the backward jaxpr. - map(partial(write, False, True), resvars) + foreach(partial(write, False, True), resvars) residuals.update(resvars) reload_eqn = core.JaxprEqn( resvars, eqn.outvars, device_put_p, @@ -1135,12 +1135,12 @@ def _partial_eval_jaxpr_custom_cached( JaxprEqnContext(None, False)) staged_eqns.append(reload_eqn) # outvars are known and available in the backward jaxpr. 
- map(partial(write, False, True), eqn.outvars) + foreach(partial(write, False, True), eqn.outvars) else: assert isinstance(policy, RecomputeType) inputs = map(ensure_instantiated, inst_in, eqn.invars) staged_eqns.append(eqn.replace(invars=inputs)) - map(partial(write, False, True), eqn.outvars) + foreach(partial(write, False, True), eqn.outvars) unzipped = unzip2(map(read, jaxpr.outvars)) out_unknowns, out_inst = list(unzipped[0]), list(unzipped[1]) assert all(type(v) is Var for v in residuals), residuals @@ -1441,14 +1441,14 @@ def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], env[x] = read(x) or b new_eqns = [] - map(write, jaxpr.outvars, used_outputs) + foreach(write, jaxpr.outvars, used_outputs) for eqn in jaxpr.eqns[::-1]: used_outs = map(read, eqn.outvars) rule = dce_rules.get(eqn.primitive, _default_dce_rule) used_ins, new_eqn = rule(used_outs, eqn) if new_eqn is not None: new_eqns.append(new_eqn) - map(write, eqn.invars, used_ins) + foreach(write, eqn.invars, used_ins) used_inputs = map(read, jaxpr.invars) used_inputs = map(op.or_, instantiate, used_inputs) @@ -2495,15 +2495,15 @@ def _eval_jaxpr_padded( def write(v, val) -> None: env[v] = val - map(write, jaxpr.constvars, consts) - map(write, jaxpr.invars, args) + foreach(write, jaxpr.constvars, consts) + foreach(write, jaxpr.invars, args) last_used = core.last_used(jaxpr) for eqn in jaxpr.eqns: in_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.invars] out_avals = [_substitute_axis_sizes(env, v.aval) for v in eqn.outvars] rule = padding_rules[eqn.primitive] outs = rule(in_avals, out_avals, *map(read, eqn.invars), **eqn.params) - map(write, eqn.outvars, outs) + foreach(write, eqn.outvars, outs) core.clean_up_dead_vars(eqn, env, last_used) return map(read, jaxpr.outvars) @@ -2580,7 +2580,7 @@ def inline_jaxpr_into_trace( src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) trace.frame.add_eqn(eqn.replace(invars, 
outvars, source_info=src_)) - map(env.setdefault, eqn.outvars, outvars) + foreach(env.setdefault, eqn.outvars, outvars) tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], [*consts, *arg_tracers])) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index d341fbedd..22d703945 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -71,7 +71,8 @@ from jax._src.sharding_impls import (PmapSharding, NamedSharding, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (NumpyComplexWarning, cache, canonicalize_axis, - safe_map, safe_zip, split_list, weakref_lru_cache) + safe_map, safe_zip, split_list, weakref_lru_cache, + foreach) _max = builtins.max _min = builtins.min @@ -106,7 +107,7 @@ def _validate_shapes(shapes: Sequence[Shape]): # pass dynamic shapes through unchecked return else: - map(_check_static_shape, shapes) + foreach(_check_static_shape, shapes) def _try_broadcast_shapes(*shapes: tuple[int, ...], name: str) -> tuple[int, ...]: """ diff --git a/jax/_src/pallas/fuser/fusable_dtype.py b/jax/_src/pallas/fuser/fusable_dtype.py index ab0075a73..e5bc9ab68 100644 --- a/jax/_src/pallas/fuser/fusable_dtype.py +++ b/jax/_src/pallas/fuser/fusable_dtype.py @@ -37,6 +37,7 @@ from jax._src.pallas.fuser import block_spec from jax._src.pallas.fuser.fusable import fusable_p from jax._src.state import discharge as state_discharge from jax._src.state import primitives as state_primitives +from jax._src.util import foreach # TODO(sharadmv): Enable type checking. 
# mypy: ignore-errors @@ -216,11 +217,11 @@ def physicalize_interp( def write_env(var: core.Var, val: Any): env[var] = val - map(write_env, jaxpr.constvars, consts) + foreach(write_env, jaxpr.constvars, consts) assert len(jaxpr.invars) == len( args ), f"Length mismatch: {jaxpr.invars} != {args}" - map(write_env, jaxpr.invars, args) + foreach(write_env, jaxpr.invars, args) for eqn in jaxpr.eqns: invals = list(map(read_env, eqn.invars)) @@ -248,7 +249,7 @@ def physicalize_interp( if eqn.primitive.multiple_results: assert len(outvals) == len(eqn.outvars) - map(write_env, eqn.outvars, outvals) + foreach(write_env, eqn.outvars, outvals) else: write_env(eqn.outvars[0], outvals) diff --git a/jax/_src/pallas/hlo_interpreter.py b/jax/_src/pallas/hlo_interpreter.py index d42402f81..8d7543b31 100644 --- a/jax/_src/pallas/hlo_interpreter.py +++ b/jax/_src/pallas/hlo_interpreter.py @@ -40,6 +40,7 @@ from jax._src.pallas import primitives from jax._src.state import discharge as state_discharge from jax._src import util from jax._src.util import ( + foreach, safe_map, safe_zip, split_list, @@ -199,8 +200,8 @@ def eval_jaxpr_recursive( env[v] = val env: dict[jax_core.Var, Any] = {} - map(write, jaxpr.constvars, consts) - map(write, jaxpr.invars, args) + foreach(write, jaxpr.constvars, consts) + foreach(write, jaxpr.invars, args) lu = jax_core.last_used(jaxpr) for eqn in jaxpr.eqns: in_vals = map(read, eqn.invars) @@ -216,7 +217,7 @@ def eval_jaxpr_recursive( subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) ans = eqn.primitive.bind(*subfuns, *in_vals, **bind_params) if eqn.primitive.multiple_results: - map(write, eqn.outvars, ans) + foreach(write, eqn.outvars, ans) else: write(eqn.outvars[0], ans) jax_core.clean_up_dead_vars(eqn, env, lu) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 7deb40163..4efb2b276 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -65,6 +65,7 @@ from 
jax._src.state import primitives as state_primitives from jax._src.state.types import RefBitcaster, RefReshaper from jax._src.state.utils import dtype_bitwidth from jax._src.typing import Array, DTypeLike +from jax._src.util import foreach from jax._src.util import safe_map from jax._src.util import safe_zip from jax._src.util import split_list @@ -950,7 +951,7 @@ def jaxpr_subcomp( for invar, bs in zip(jaxpr.invars, ctx.block_shapes): block_shape_env[invar] = bs - map(write_env, jaxpr.invars, args) + foreach(write_env, jaxpr.invars, args) initial_name_stack = [scope.name for scope in ctx.name_stack.stack] current_name_stack: list[str] = [] @@ -1011,7 +1012,7 @@ def jaxpr_subcomp( f"{eqn.primitive.name}. " "Please file an issue on https://github.com/jax-ml/jax/issues.") if eqn.primitive.multiple_results: - map(write_env, eqn.outvars, ans) + foreach(write_env, eqn.outvars, ans) else: write_env(eqn.outvars[0], ans) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index 0333a9b03..5f0f83b4e 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -52,6 +52,7 @@ from jax._src.state import indexing from jax._src.state import primitives as sp from jax._src.state import types as state_types from jax._src.state.types import RefReshaper +from jax._src.util import foreach import jax.experimental.mosaic.gpu as mgpu from jax.experimental.mosaic.gpu import core as mgpu_core from jax.experimental.mosaic.gpu import profiler as mgpu_profiler @@ -738,8 +739,8 @@ def lower_jaxpr_to_mosaic_gpu( if val.type != mlir_dtype: raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}") - map(write_env, jaxpr.constvars, consts) - map(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) + foreach(write_env, jaxpr.constvars, consts) + foreach(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) # TODO(justinfu): Handle transform 
scopes. last_local_name_stack: list[str] = [] named_regions = [] @@ -786,7 +787,7 @@ def lower_jaxpr_to_mosaic_gpu( f" {rule_ctx}\nWith inval types={inval_types}\nIn jaxpr:\n{jaxpr}" ) from e if eqn.primitive.multiple_results: - map(write_env, eqn.outvars, outvals) + foreach(write_env, eqn.outvars, outvals) else: write_env(eqn.outvars[0], outvals) while named_regions: # Drain the name stack. diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index c2080b9c6..a48fec61b 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -33,6 +33,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives +from jax._src.util import foreach from jax.experimental import pallas as pl import jax.numpy as jnp @@ -247,7 +248,7 @@ def emit_pipeline( it.islice(it.product(*map(range, grid)), max_concurrent_steps) ): indices = tuple(map(lambda i: jnp.asarray(i, dtype=jnp.int32), indices)) - map(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) + foreach(lambda bref: bref.copy_in(step, indices, barrier_ref), in_brefs) # This is true if any of the outputs need to be transferred inside the loop. 
copies_out_in_loop = not all(bref.is_index_invariant for bref in out_brefs) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 57a76a37d..b1e1da34f 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -44,6 +44,7 @@ from jax._src.pallas import primitives from jax._src.state import discharge as state_discharge from jax._src.state import types as state_types from jax._src.util import ( + foreach, safe_map, safe_zip, split_list, @@ -978,7 +979,7 @@ def pallas_call_checkify_oob_grid(error: checkify.Error, for bm in grid_mapping.block_mappings] # We perform a dynamic slice on the i/o blocks, which will be checked by # checkify for OOB accesses. - map(hlo_interpreter._dynamic_slice, start_indices, block_shapes, + foreach(hlo_interpreter._dynamic_slice, start_indices, block_shapes, [*input_args, *output_args], is_indexing_dim) return (i + 1, hlo_interpreter._get_next_indices(grid, loop_idx)) def f(_): diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 0612ae5b8..f3a8dd175 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -53,6 +53,7 @@ from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge from jax._src.state import indexing from jax._src.state import primitives as sp +from jax._src.util import foreach from jax._src.util import merge_lists from jax._src.util import partition_list from jax._src.util import split_list @@ -380,7 +381,7 @@ def lower_jaxpr_to_triton_ir( if block_info is not None: block_info_env[invar] = block_info - map(write_env, jaxpr.invars, args) + foreach(write_env, jaxpr.invars, args) for eqn in jaxpr.eqns: invals = map(read_env, eqn.invars) @@ -410,7 +411,7 @@ def lower_jaxpr_to_triton_ir( f"msg={e}" ) from e if eqn.primitive.multiple_results: - map(write_env, eqn.outvars, outvals) + foreach(write_env, eqn.outvars, outvals) else: write_env(eqn.outvars[0], outvals) diff --git 
a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index dbfe46af4..7ab77d5b1 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -45,6 +45,7 @@ from jax._src.state.types import ( from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array from jax._src.util import ( + foreach, merge_lists, partition_list, safe_map, @@ -140,10 +141,10 @@ def _eval_jaxpr_discharge_state( *args: Any): env = Environment({}) - map(env.write, jaxpr.constvars, consts) + foreach(env.write, jaxpr.constvars, consts) # Here some args may correspond to `Ref` avals but they'll be treated like # regular values in this interpreter. - map(env.write, jaxpr.invars, args) + foreach(env.write, jaxpr.invars, args) refs_to_discharge = {id(v.aval) for v, d in zip(jaxpr.invars, should_discharge) if d and isinstance(v.aval, AbstractRef)} @@ -195,7 +196,7 @@ def _eval_jaxpr_discharge_state( ans = eqn.primitive.bind(*subfuns, *map(env.read, eqn.invars), **bind_params) if eqn.primitive.multiple_results: - map(env.write, eqn.outvars, ans) + foreach(env.write, eqn.outvars, ans) else: env.write(eqn.outvars[0], ans) # By convention, we return the outputs of the jaxpr first and then the final diff --git a/jax/_src/util.py b/jax/_src/util.py index 0b4983c81..408106c12 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -90,6 +90,31 @@ if TYPE_CHECKING: else: safe_map = jaxlib_utils.safe_map +if TYPE_CHECKING: + @overload + def foreach(f: Callable[[T1], Any], __arg1: Iterable[T1]) -> None: ... + + @overload + def foreach(f: Callable[[T1, T2], Any], __arg1: Iterable[T1], __arg2: Iterable[T2]) -> None: ... + + @overload + def foreach(f: Callable[[T1, T2, T3], Any], __arg1: Iterable[T1], __arg2: Iterable[T2], __arg3: Iterable[T3]) -> None: ... + + @overload + def foreach(f: Callable[..., Any], __arg1: Iterable[Any], __arg2: Iterable[Any], __arg3: Iterable[Any], __arg4: Iterable[Any], *args) -> None: ... 
+ + def foreach(f, *args): + safe_map(f, *args) + return None + +else: + # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum. + if hasattr(jaxlib_utils, 'foreach'): + foreach = jaxlib_utils.foreach + else: + foreach = safe_map + + def unzip2(xys: Iterable[tuple[T1, T2]] ) -> tuple[tuple[T1, ...], tuple[T2, ...]]: """Unzip sequence of length-2 tuples into two tuples.""" diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py index d5a1bb91d..6a7f2916b 100644 --- a/jax/experimental/roofline/roofline.py +++ b/jax/experimental/roofline/roofline.py @@ -28,6 +28,7 @@ from jax._src.api import make_jaxpr from jax._src.interpreters.partial_eval import dce_jaxpr from jax._src.mesh import AbstractMesh, Mesh from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map +from jax._src.util import foreach from jax.experimental import shard_map @@ -165,12 +166,12 @@ def _roofline_interpreter( jaxpr = jaxpr.jaxpr if isinstance(jaxpr, core.ClosedJaxpr) else jaxpr make_roofline_shape = lambda x: RooflineShape.from_aval(aval(x)) - map( + foreach( write, jaxpr.constvars, map(make_roofline_shape, jaxpr.constvars), ) - map(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars)) + foreach(write, jaxpr.invars, map(make_roofline_shape, jaxpr.invars)) last_used = core.last_used(jaxpr) for eqn in jaxpr.eqns: source_info = eqn.source_info.replace( @@ -210,7 +211,7 @@ def _roofline_interpreter( **eqn.params, ) - map(write, eqn.outvars, map(make_roofline_shape, eqn.outvars)) + foreach(write, eqn.outvars, map(make_roofline_shape, eqn.outvars)) core.clean_up_dead_vars(eqn, env, last_used) result += RooflineResult(peak_hbm_bytes=calculate_peak_hbm_bytes()) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 662207d36..c6a2d8d7a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -58,7 +58,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import sdy 
from jax._src.util import (HashableFunction, HashablePartial, unzip2, as_hashable_function, memoize, partition_list, - split_list, subs_list2) + split_list, subs_list2, foreach) from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe @@ -648,15 +648,15 @@ def _check_rep(mesh: Mesh, jaxpr: core.Jaxpr, in_rep: Sequence[RepType] def write(v: core.Var, val: RepType) -> None: env[v] = val - map(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars)) - map(write, jaxpr.invars, in_rep) + foreach(write, jaxpr.constvars, [set(mesh.axis_names)] * len(jaxpr.constvars)) + foreach(write, jaxpr.invars, in_rep) last_used = core.last_used(jaxpr) for e in jaxpr.eqns: rule = _check_rules.get(e.primitive, partial(_rule_missing, e.primitive)) out_rep = rule(mesh, *map(read, e.invars), **e.params) if e.primitive.multiple_results: out_rep = [out_rep] * len(e.outvars) if type(out_rep) is set else out_rep - map(write, e.outvars, out_rep) + foreach(write, e.outvars, out_rep) else: write(e.outvars[0], out_rep) core.clean_up_dead_vars(e, env, last_used) diff --git a/jax/experimental/slab/djax.py b/jax/experimental/slab/djax.py index c989ede86..18f47515a 100644 --- a/jax/experimental/slab/djax.py +++ b/jax/experimental/slab/djax.py @@ -26,6 +26,7 @@ from jax import lax from jax._src import core from jax._src import util +from jax._src.util import foreach import jax.experimental.slab.slab as sl @@ -107,11 +108,11 @@ def eval_djaxpr(jaxpr: core.Jaxpr, slab: sl.Slab, *args: jax.Array | sl.SlabView def write(v, val): env[v] = val - map(write, jaxpr.invars, args) + foreach(write, jaxpr.invars, args) for eqn in jaxpr.eqns: invals = map(read, eqn.invars) slab, outvals = rules[eqn.primitive](slab, *invals, **eqn.params) - map(write, eqn.outvars, outvals) + foreach(write, eqn.outvars, outvals) return slab, map(read, jaxpr.outvars) rules: dict[core.Primitive, Callable] = {} diff --git a/jaxlib/utils.cc 
b/jaxlib/utils.cc index 6b612b26d..8b673ef68 100644 --- a/jaxlib/utils.cc +++ b/jaxlib/utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include +#include + #include "nanobind/nanobind.h" #include "absl/cleanup/cleanup.h" #include "absl/container/inlined_vector.h" @@ -125,6 +127,81 @@ PyMethodDef safe_map_def = { METH_FASTCALL, }; +// Similar to SafeMap, but ignores the return values of the function and returns +// None. +PyObject* ForEach(PyObject* self, PyObject* const* args, Py_ssize_t nargs) { + if (nargs < 2) { + PyErr_SetString(PyExc_TypeError, "foreach() requires at least 2 arguments"); + return nullptr; + } + PyObject* fn = args[0]; + absl::InlinedVector<nb::object, 4> iterators; + iterators.reserve(nargs - 1); + for (Py_ssize_t i = 1; i < nargs; ++i) { + PyObject* it = PyObject_GetIter(args[i]); + if (!it) return nullptr; + iterators.push_back(nb::steal(it)); + } + + // The arguments we will pass to fn. We allocate space for one more argument + // than we need at the start of the argument list so we can use + // PY_VECTORCALL_ARGUMENTS_OFFSET which may speed up the callee. + absl::InlinedVector<PyObject*, 4> values(nargs, nullptr); + while (true) { + absl::Cleanup values_cleanup = [&values]() { + // Bind by reference so each slot is really nulled after the decref; + // otherwise an early return would decref stale pointers a second time. + for (PyObject*& v : values) { + Py_XDECREF(v); + v = nullptr; + } + }; + values[1] = PyIter_Next(iterators[0].ptr()); + if (PyErr_Occurred()) return nullptr; + + if (values[1]) { + for (size_t i = 1; i < iterators.size(); ++i) { + values[i + 1] = PyIter_Next(iterators[i].ptr()); + if (PyErr_Occurred()) return nullptr; + if (!values[i + 1]) { + PyErr_Format(PyExc_ValueError, + "foreach() argument %zu is shorter than argument 1", + i + 1); + return nullptr; + } + } + } else { + // No more elements should be left. Check that the other iterators are also + // exhausted. 
+ for (size_t i = 1; i < iterators.size(); ++i) { + values[i + 1] = PyIter_Next(iterators[i].ptr()); + if (PyErr_Occurred()) return nullptr; + if (values[i + 1]) { + PyErr_Format(PyExc_ValueError, + "foreach() argument %zu is longer than argument 1", + i + 1); + return nullptr; + } + } + Py_INCREF(Py_None); + return Py_None; + } + + nb::object out = nb::steal(PyObject_Vectorcall( + fn, &values[1], (nargs - 1) | PY_VECTORCALL_ARGUMENTS_OFFSET, + /*kwnames=*/nullptr)); + if (PyErr_Occurred()) { + return nullptr; + } + } +} + +PyMethodDef foreach_def = { + "foreach", reinterpret_cast<PyCFunction>(ForEach), METH_FASTCALL, + "foreach() applies a function elementwise to one or more iterables, " + "ignoring the return values and returns None. The iterables must all have " + "the same lengths."}; + // A variant of zip(...) that: // a) returns a list instead of an iterator, and // b) checks that the input iterables are of equal length. @@ -224,6 +301,8 @@ NB_MODULE(utils, m) { nb::object module_name = m.attr("__name__"); m.attr("safe_map") = nb::steal( PyCFunction_NewEx(&safe_map_def, /*self=*/nullptr, module_name.ptr())); + m.attr("foreach") = nb::steal( + PyCFunction_NewEx(&foreach_def, /*self=*/nullptr, module_name.ptr())); m.attr("safe_zip") = nb::steal( PyCFunction_NewEx(&safe_zip_def, /*self=*/nullptr, module_name.ptr())); diff --git a/tests/generated_fun_test.py b/tests/generated_fun_test.py index a288e1a5f..cdfeeba62 100644 --- a/tests/generated_fun_test.py +++ b/tests/generated_fun_test.py @@ -30,7 +30,7 @@ jax.config.parse_flags_with_absl() npr.seed(0) -from jax._src.util import unzip2, safe_zip, safe_map +from jax._src.util import foreach, unzip2, safe_zip, safe_map map = safe_map zip = safe_zip @@ -87,10 +87,10 @@ def eval_fun(fun, *args): env[v] = x env = {} - map(write, fun.in_vars, args) + foreach(write, fun.in_vars, args) for in_vars, out_vars, f in fun.eqns: out_vals = f(*map(read, in_vars)) - map(write, out_vars, out_vals) + foreach(write, out_vars, out_vals) return 
map(read, fun.out_vars)