From 148774587ace00270d1f8f6c9c9ef733f45deedb Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 27 Feb 2023 11:37:10 -0800 Subject: [PATCH] Remove circular dependency between source_info_util and util. Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend(). PiperOrigin-RevId: 512685810 --- jax/_src/api.py | 7 ++++--- jax/_src/dispatch.py | 2 +- jax/_src/interpreters/mlir.py | 3 +-- jax/_src/interpreters/pxla.py | 12 ++++++------ jax/_src/interpreters/xla.py | 5 ++--- jax/_src/lax/control_flow/conditionals.py | 7 +++---- jax/_src/lax/control_flow/loops.py | 11 ++++------- jax/_src/lax/lax.py | 3 ++- jax/_src/lax/slicing.py | 4 +++- jax/_src/maps.py | 11 ++++------- jax/_src/source_info_util.py | 8 ++++++++ jax/_src/util.py | 12 ------------ jax/interpreters/pxla.py | 1 - 13 files changed, 38 insertions(+), 48 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 1a1e10948..316221b5a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -66,8 +66,8 @@ from jax._src.sharding import PmapSharding from jax._src.traceback_util import api_boundary from jax._src.tree_util import broadcast_prefix, _generate_key_paths from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list, - extend_name_stack, new_name_stack, wrap_name, cache, - wraps, HashableFunction, weakref_lru_cache) + wrap_name, cache, wraps, HashableFunction, + weakref_lru_cache) # Unused imports to be exported @@ -1083,7 +1083,8 @@ def xla_computation(fun: Callable, backend_or_name=backend, platform=platform, axis_context=mlir.ReplicaAxisContext(axis_env_), - name_stack=new_name_stack(wrap_name(fun_name, "xla_computation")), + name_stack=source_info_util.new_name_stack( + wrap_name(fun_name, "xla_computation")), donated_args=donated_invars, arg_shardings=(None if in_parts_flat is None else map( xla.sharding_to_proto, in_parts_flat)), diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index b08cdf526..0958c35e3 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -499,7 +499,7 @@ def lower_xla_callable( # pass long arg lists as tuple for TPU tuple_args = should_tuple_args(len(abstract_args), backend.platform) axis_env = xla.AxisEnv(nreps, (), ()) - name_stack = util.new_name_stack(util.wrap_name(name, 'jit')) + name_stack = source_info_util.new_name_stack(util.wrap_name(name, 'jit')) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) closed_out_type = [ (a.update(shape=tuple(pe.InDBIdx(d.val - len(consts)) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 13aa72f9f..6031c0a81 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1014,8 +1014,7 @@ def lower_jaxpr_to_fun( args.append(hlo.CreateTokenOp().results) else: args.append(arg) - callee_name_stack = util.extend_name_stack( - ctx.name_stack, util.wrap_name(name, api_name)) + callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name)) out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack), jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts), *args, dim_var_values=dim_var_values) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 50cbc22cc..5988ad19b 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -78,7 +78,7 @@ from jax._src.lib import pmap_lib from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list, - new_name_stack, wrap_name, assert_unreachable, + wrap_name, assert_unreachable, tuple_insert, tuple_delete, distributed_debug_log, unzip2, HashableFunction) @@ -1478,7 +1478,7 @@ def lower_parallel_callable( axis_env = xla.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) - name_stack = new_name_stack(wrap_name(name, 'pmap')) + name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap')) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) replicated_args = [axis is None for axis in in_axes] tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals), @@ -2311,8 +2311,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore sub_ctx = ctx.module_context.replace( axis_context=mlir.ReplicaAxisContext(new_env), - name_stack=util.extend_name_stack(ctx.module_context.name_stack, - util.wrap_name(name, 'pmap'))) + name_stack=ctx.module_context.name_stack.extend( + util.wrap_name(name, 'pmap'))) sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (), *in_nodes_sharded, dim_var_values=ctx.dim_var_values) @@ -2852,7 +2852,7 @@ def lower_sharding_computation( the singleton _UNSPECIFIED to all out_avals. """ # 1. Trace to jaxpr and preprocess/verify it - name_stack = new_name_stack(wrap_name(fun_name, api_name)) + name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " "in {elapsed_time} sec", @@ -3068,7 +3068,7 @@ def lower_mesh_computation( in_is_global: Sequence[bool]) -> MeshComputation: assert not mesh.empty backend = xb.get_device_backend(mesh.devices.flat[0]) - name_stack = new_name_stack(wrap_name(fun_name, api_name)) + name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name)) auto_spmd_lowering = check_if_any_auto(in_shardings + out_shardings) # type: ignore diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index e397e90d0..74ccf9519 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -37,8 +37,7 @@ from jax._src import source_info_util from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ConcreteArray, ShapedArray from jax._src.interpreters import ad -from jax._src.util import (prod, new_name_stack, safe_zip, safe_map, - partition_list) +from jax._src.util import (prod, safe_zip, safe_map, partition_list) from jax._src.typing import Shape @@ -251,7 +250,7 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv', rule = _translations[prim] ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env, - name_stack=new_name_stack()) + name_stack=source_info_util.new_name_stack()) ans = rule(ctx, avals_in, avals_out, *xla_args, **params) if prim.multiple_results: diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 02de3ee49..649079cd9 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -39,8 +39,7 @@ from jax._src import state from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects from jax._src.lax import lax from jax._src.traceback_util import api_boundary -from jax._src.util import (safe_map, extend_name_stack, split_list, - partition_list) +from jax._src.util import (safe_map, split_list, partition_list) from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import numpy as np @@ -832,12 +831,12 @@ def _cond_lowering(ctx, index, *args, branches, linear): # captures. case_op = hlo.CaseOp(flat_output_types, index=index, num_branches=len(branches)) - name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond') + name_stack = ctx.module_context.name_stack.extend('cond') for i, jaxpr in enumerate(branches): branch = case_op.regions[i].blocks.append() with ir.InsertionPoint(branch): sub_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(name_stack, f'branch_{i}_fun')) + name_stack=name_stack.extend(f'branch_{i}_fun')) out_vals, tokens_out = mlir.jaxpr_subcomp( sub_ctx, jaxpr.jaxpr, tokens_in, map(mlir.ir_constants, jaxpr.consts), diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 7038ac1bd..b99dddac5 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -46,7 +46,6 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.ufuncs import logaddexp from jax._src.traceback_util import api_boundary from jax._src.util import ( - extend_name_stack, partition_list, safe_map, safe_zip, @@ -1480,7 +1479,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, # Loop condition cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types) - name_stack = extend_name_stack(ctx.module_context.name_stack, 'while') + name_stack = ctx.module_context.name_stack.extend('while') with ir.InsertionPoint(cond_block): flat_cond_args = [ cond_block.arguments[i] for i in range(len(flat_loop_carry_types)) @@ -1489,8 +1488,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, # Remove tokens from cond args cond_args = cond_args[num_tokens:] x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts]) - cond_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(name_stack, 'cond')) + cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond')) ((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(), _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z), dim_var_values=ctx.dim_var_values) @@ -1521,15 +1519,14 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, token_args, body_args = util.split_list(body_args, [num_tokens]) tokens_in = mlir.TokenSet(zip(body_effects, token_args)) x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts]) - body_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(name_stack, 'body')) + body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body')) new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr, tokens_in, _map(mlir.ir_constants, body_jaxpr.consts), *(y + z), dim_var_values=ctx.dim_var_values) out_tokens = [tokens_out.get(eff) for eff in body_effects] if batched: body_pred_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(name_stack, 'body_pred')) + name_stack=name_stack.extend('body_pred')) ((body_pred,),), _ = mlir.jaxpr_subcomp( body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(), _map(mlir.ir_constants, cond_jaxpr.consts), diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ead33d91b..7efeddb16 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -3638,7 +3638,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions): ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals] reducer = op.regions[0].blocks.append(*(ir_types + ir_types)) with ir.InsertionPoint(reducer): - reducer_ctx = ctx.module_context.replace(name_stack=util.new_name_stack()) + reducer_ctx = ctx.module_context.replace( + name_stack=source_info_util.new_name_stack()) if jaxpr.effects: raise NotImplementedError('Cannot lower effectful `reduce`.') out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts, diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 91755d557..f7401e634 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -25,6 +25,7 @@ from jax.interpreters import partial_eval as pe from jax._src import ad_util from jax._src import core from jax._src import dtypes +from jax._src import source_info_util from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching @@ -2016,7 +2017,8 @@ def _scatter_lower(ctx, operand, indices, updates, *, scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype)) update = op.update_computation.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(update): - update_ctx = ctx.module_context.replace(name_stack=util.new_name_stack()) + update_ctx = ctx.module_context.replace( + name_stack=source_info_util.new_name_stack()) if update_jaxpr.effects: raise NotImplementedError('Cannot lower effectful `scatter`.') out_nodes, _ = mlir.jaxpr_subcomp( diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 24fd2a481..436f9417c 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -47,7 +47,7 @@ from jax.interpreters import ad from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3, as_hashable_function, distributed_debug_log, tuple_insert, moveaxis, split_list, wrap_name, - merge_lists, partition_list, extend_name_stack) + merge_lists, partition_list) from jax import lax source_info_util.register_exclusion(__file__) @@ -1374,8 +1374,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes, # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. sub_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(ctx.module_context.name_stack, - wrap_name(name, 'xmap'))) + name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap'))) if any(effects.ordered_effects.contains(eff) for eff in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') @@ -1442,8 +1441,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. sub_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(ctx.module_context.name_stack, - wrap_name(name, 'xmap'))) + name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap'))) if any(effects.ordered_effects.contains(eff) for eff in vectorized_jaxpr.effects): raise NotImplementedError('Cannot lower `xmap` with ordered effects.') @@ -1494,8 +1492,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, # translation rule just because the extra tuple stuff is a pain. assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext) sub_ctx = ctx.module_context.replace( - name_stack=extend_name_stack(ctx.module_context.name_stack, - wrap_name(name, 'xmap')), + name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')), axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes)) if any(effects.ordered_effects.contains(eff) for eff in vectorized_jaxpr.effects): diff --git a/jax/_src/source_info_util.py b/jax/_src/source_info_util.py index b71114148..93d8e018a 100644 --- a/jax/_src/source_info_util.py +++ b/jax/_src/source_info_util.py @@ -102,6 +102,14 @@ class NameStack: scope = elem.wrap(scope) return '/'.join(scope) + +def new_name_stack(name: str = '') -> NameStack: + name_stack = NameStack() + if name: + name_stack = name_stack.extend(name) + return name_stack + + class SourceInfo(NamedTuple): traceback: Optional[Traceback] name_stack: NameStack diff --git a/jax/_src/util.py b/jax/_src/util.py index 6f1744e23..efc31eb1a 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -330,18 +330,6 @@ def get_module_functions(module): def wrap_name(name, transform_name): return transform_name + '(' + name + ')' -def new_name_stack(name: str = ''): - from jax._src import source_info_util - name_stack = source_info_util.NameStack() - if name: - name_stack = name_stack.extend(name) - return name_stack - -def extend_name_stack(stack, name: str): - from jax._src import source_info_util - assert isinstance(stack, source_info_util.NameStack), stack - return stack.extend(name) - def canonicalize_axis(axis, num_dims) -> int: """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).""" axis = operator.index(axis) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 38dabfa80..f1f96e0c6 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -95,7 +95,6 @@ from jax._src.interpreters.pxla import ( mesh_sharding_specs as mesh_sharding_specs, multi_host_supported_collectives as multi_host_supported_collectives, new_mesh_sharding_specs as new_mesh_sharding_specs, - new_name_stack as new_name_stack, op_sharding_to_indices as op_sharding_to_indices, parallel_callable as parallel_callable, partitioned_sharding_spec as partitioned_sharding_spec,