Remove circular dependency between source_info_util and util.

Move util.new_name_stack into source_info_util, and replace uses of util.extend_name_stack(stack, name) with stack.extend(name).

PiperOrigin-RevId: 512685810
Authored by Peter Hawkins on 2023-02-27 11:37:10 -08:00; committed by jax authors
parent bcf378f6b4
commit 148774587a
13 changed files with 38 additions and 48 deletions
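
For illustration, a minimal before/after sketch of the call-site pattern this commit changes (a hand-written example, not part of the diff; it assumes a jax checkout that already includes this change):

# Before this commit, both helpers lived in jax._src.util and had to import
# source_info_util lazily inside their function bodies (see the util.py hunk below):
#     name_stack = util.new_name_stack(util.wrap_name('f', 'jit'))
#     name_stack = util.extend_name_stack(name_stack, 'cond')
# After this commit, construction lives next to NameStack itself and extension
# is just a method call:
from jax._src import source_info_util, util

name_stack = source_info_util.new_name_stack(util.wrap_name('f', 'jit'))  # NameStack for "jit(f)"
name_stack = name_stack.extend('cond')  # no dependency on util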


@@ -66,8 +66,8 @@ from jax._src.sharding import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix, _generate_key_paths
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, new_name_stack, wrap_name, cache,
wraps, HashableFunction, weakref_lru_cache)
wrap_name, cache, wraps, HashableFunction,
weakref_lru_cache)
# Unused imports to be exported
@@ -1083,7 +1083,8 @@ def xla_computation(fun: Callable,
backend_or_name=backend,
platform=platform,
axis_context=mlir.ReplicaAxisContext(axis_env_),
name_stack=new_name_stack(wrap_name(fun_name, "xla_computation")),
name_stack=source_info_util.new_name_stack(
wrap_name(fun_name, "xla_computation")),
donated_args=donated_invars,
arg_shardings=(None if in_parts_flat is None else map(
xla.sharding_to_proto, in_parts_flat)),


@@ -499,7 +499,7 @@ def lower_xla_callable(
# pass long arg lists as tuple for TPU
tuple_args = should_tuple_args(len(abstract_args), backend.platform)
axis_env = xla.AxisEnv(nreps, (), ())
name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
name_stack = source_info_util.new_name_stack(util.wrap_name(name, 'jit'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
closed_out_type = [
(a.update(shape=tuple(pe.InDBIdx(d.val - len(consts))


@@ -1014,8 +1014,7 @@ def lower_jaxpr_to_fun(
args.append(hlo.CreateTokenOp().results)
else:
args.append(arg)
callee_name_stack = util.extend_name_stack(
ctx.name_stack, util.wrap_name(name, api_name))
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
*args, dim_var_values=dim_var_values)


@@ -78,7 +78,7 @@ from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
unzip2, HashableFunction)
@@ -1478,7 +1478,7 @@ def lower_parallel_callable(
axis_env = xla.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
name_stack = new_name_stack(wrap_name(name, 'pmap'))
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
replicated_args = [axis is None for axis in in_axes]
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
@@ -2311,8 +2311,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
sub_ctx = ctx.module_context.replace(
axis_context=mlir.ReplicaAxisContext(new_env),
name_stack=util.extend_name_stack(ctx.module_context.name_stack,
util.wrap_name(name, 'pmap')))
name_stack=ctx.module_context.name_stack.extend(
util.wrap_name(name, 'pmap')))
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
*in_nodes_sharded,
dim_var_values=ctx.dim_var_values)
@@ -2852,7 +2852,7 @@ def lower_sharding_computation(
the singleton _UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
name_stack = new_name_stack(wrap_name(fun_name, api_name))
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec",
@@ -3068,7 +3068,7 @@ def lower_mesh_computation(
in_is_global: Sequence[bool]) -> MeshComputation:
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = new_name_stack(wrap_name(fun_name, api_name))
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
auto_spmd_lowering = check_if_any_auto(in_shardings + out_shardings) # type: ignore


@@ -37,8 +37,7 @@ from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
partition_list)
from jax._src.util import (prod, safe_zip, safe_map, partition_list)
from jax._src.typing import Shape
@@ -251,7 +250,7 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
rule = _translations[prim]
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
name_stack=new_name_stack())
name_stack=source_info_util.new_name_stack())
ans = rule(ctx, avals_in, avals_out, *xla_args, **params)
if prim.multiple_results:


@@ -39,8 +39,7 @@ from jax._src import state
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, extend_name_stack, split_list,
partition_list)
from jax._src.util import (safe_map, split_list, partition_list)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
import numpy as np
@@ -832,12 +831,12 @@ def _cond_lowering(ctx, index, *args, branches, linear):
# captures.
case_op = hlo.CaseOp(flat_output_types, index=index,
num_branches=len(branches))
name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
name_stack = ctx.module_context.name_stack.extend('cond')
for i, jaxpr in enumerate(branches):
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(name_stack, f'branch_{i}_fun'))
name_stack=name_stack.extend(f'branch_{i}_fun'))
out_vals, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr.jaxpr, tokens_in,
map(mlir.ir_constants, jaxpr.consts),


@@ -46,7 +46,6 @@ from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.util import (
extend_name_stack,
partition_list,
safe_map,
safe_zip,
@@ -1480,7 +1479,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
# Loop condition
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
name_stack = ctx.module_context.name_stack.extend('while')
with ir.InsertionPoint(cond_block):
flat_cond_args = [
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
@@ -1489,8 +1488,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
# Remove tokens from cond args
cond_args = cond_args[num_tokens:]
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(name_stack, 'cond'))
cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond'))
((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
_map(mlir.ir_constants, cond_jaxpr.consts),
*(x + z), dim_var_values=ctx.dim_var_values)
@@ -1521,15 +1519,14 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
token_args, body_args = util.split_list(body_args, [num_tokens])
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(name_stack, 'body'))
body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body'))
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
tokens_in, _map(mlir.ir_constants, body_jaxpr.consts),
*(y + z), dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in body_effects]
if batched:
body_pred_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(name_stack, 'body_pred'))
name_stack=name_stack.extend('body_pred'))
((body_pred,),), _ = mlir.jaxpr_subcomp(
body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
_map(mlir.ir_constants, cond_jaxpr.consts),


@@ -3638,7 +3638,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
with ir.InsertionPoint(reducer):
reducer_ctx = ctx.module_context.replace(name_stack=util.new_name_stack())
reducer_ctx = ctx.module_context.replace(
name_stack=source_info_util.new_name_stack())
if jaxpr.effects:
raise NotImplementedError('Cannot lower effectful `reduce`.')
out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts,


@@ -25,6 +25,7 @@ from jax.interpreters import partial_eval as pe
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -2016,7 +2017,8 @@ def _scatter_lower(ctx, operand, indices, updates, *,
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
update = op.update_computation.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(update):
update_ctx = ctx.module_context.replace(name_stack=util.new_name_stack())
update_ctx = ctx.module_context.replace(
name_stack=source_info_util.new_name_stack())
if update_jaxpr.effects:
raise NotImplementedError('Cannot lower effectful `scatter`.')
out_nodes, _ = mlir.jaxpr_subcomp(


@@ -47,7 +47,7 @@ from jax.interpreters import ad
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list, extend_name_stack)
merge_lists, partition_list)
from jax import lax
source_info_util.register_exclusion(__file__)
@@ -1374,8 +1374,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
if any(effects.ordered_effects.contains(eff) for eff
in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
@@ -1442,8 +1441,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')))
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
if any(effects.ordered_effects.contains(eff) for eff
in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
@@ -1494,8 +1492,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# translation rule just because the extra tuple stuff is a pain.
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
sub_ctx = ctx.module_context.replace(
name_stack=extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, 'xmap')),
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')),
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
if any(effects.ordered_effects.contains(eff) for eff
in vectorized_jaxpr.effects):


@@ -102,6 +102,14 @@ class NameStack:
scope = elem.wrap(scope)
return '/'.join(scope)
def new_name_stack(name: str = '') -> NameStack:
name_stack = NameStack()
if name:
name_stack = name_stack.extend(name)
return name_stack
class SourceInfo(NamedTuple):
traceback: Optional[Traceback]
name_stack: NameStack
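
A short usage sketch of the relocated helper (the example names are hypothetical; the rendered form is inferred from the '/'.join above):

from jax._src.source_info_util import new_name_stack

# Build a scope by chaining extend(); each call returns a new NameStack.
stack = new_name_stack('jit(f)').extend('while').extend('body')
# A NameStack renders as a '/'-joined scope, roughly "jit(f)/while/body".
print(stack)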


@@ -330,18 +330,6 @@ def get_module_functions(module):
def wrap_name(name, transform_name):
return transform_name + '(' + name + ')'
def new_name_stack(name: str = ''):
from jax._src import source_info_util
name_stack = source_info_util.NameStack()
if name:
name_stack = name_stack.extend(name)
return name_stack
def extend_name_stack(stack, name: str):
from jax._src import source_info_util
assert isinstance(stack, source_info_util.NameStack), stack
return stack.extend(name)
def canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)


@@ -95,7 +95,6 @@ from jax._src.interpreters.pxla import (
mesh_sharding_specs as mesh_sharding_specs,
multi_host_supported_collectives as multi_host_supported_collectives,
new_mesh_sharding_specs as new_mesh_sharding_specs,
new_name_stack as new_name_stack,
op_sharding_to_indices as op_sharding_to_indices,
parallel_callable as parallel_callable,
partitioned_sharding_spec as partitioned_sharding_spec,