mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Remove circular dependency between source_info_util and util.
Move util.new_name_stack into source_info_util. Replace uses of util.extend_name_stack with stack.extend(). PiperOrigin-RevId: 512685810
This commit is contained in:
parent
bcf378f6b4
commit
148774587a
@ -66,8 +66,8 @@ from jax._src.sharding import PmapSharding
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import broadcast_prefix, _generate_key_paths
|
||||
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
extend_name_stack, new_name_stack, wrap_name, cache,
|
||||
wraps, HashableFunction, weakref_lru_cache)
|
||||
wrap_name, cache, wraps, HashableFunction,
|
||||
weakref_lru_cache)
|
||||
|
||||
|
||||
# Unused imports to be exported
|
||||
@ -1083,7 +1083,8 @@ def xla_computation(fun: Callable,
|
||||
backend_or_name=backend,
|
||||
platform=platform,
|
||||
axis_context=mlir.ReplicaAxisContext(axis_env_),
|
||||
name_stack=new_name_stack(wrap_name(fun_name, "xla_computation")),
|
||||
name_stack=source_info_util.new_name_stack(
|
||||
wrap_name(fun_name, "xla_computation")),
|
||||
donated_args=donated_invars,
|
||||
arg_shardings=(None if in_parts_flat is None else map(
|
||||
xla.sharding_to_proto, in_parts_flat)),
|
||||
|
@ -499,7 +499,7 @@ def lower_xla_callable(
|
||||
# pass long arg lists as tuple for TPU
|
||||
tuple_args = should_tuple_args(len(abstract_args), backend.platform)
|
||||
axis_env = xla.AxisEnv(nreps, (), ())
|
||||
name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
|
||||
name_stack = source_info_util.new_name_stack(util.wrap_name(name, 'jit'))
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
closed_out_type = [
|
||||
(a.update(shape=tuple(pe.InDBIdx(d.val - len(consts))
|
||||
|
@ -1014,8 +1014,7 @@ def lower_jaxpr_to_fun(
|
||||
args.append(hlo.CreateTokenOp().results)
|
||||
else:
|
||||
args.append(arg)
|
||||
callee_name_stack = util.extend_name_stack(
|
||||
ctx.name_stack, util.wrap_name(name, api_name))
|
||||
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
|
||||
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
|
||||
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
|
||||
*args, dim_var_values=dim_var_values)
|
||||
|
@ -78,7 +78,7 @@ from jax._src.lib import pmap_lib
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
|
||||
new_name_stack, wrap_name, assert_unreachable,
|
||||
wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, distributed_debug_log,
|
||||
unzip2, HashableFunction)
|
||||
|
||||
@ -1478,7 +1478,7 @@ def lower_parallel_callable(
|
||||
|
||||
axis_env = xla.AxisEnv(
|
||||
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
|
||||
name_stack = new_name_stack(wrap_name(name, 'pmap'))
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
replicated_args = [axis is None for axis in in_axes]
|
||||
tuple_args = dispatch.should_tuple_args(len(shards.global_sharded_avals),
|
||||
@ -2311,8 +2311,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
axis_context=mlir.ReplicaAxisContext(new_env),
|
||||
name_stack=util.extend_name_stack(ctx.module_context.name_stack,
|
||||
util.wrap_name(name, 'pmap')))
|
||||
name_stack=ctx.module_context.name_stack.extend(
|
||||
util.wrap_name(name, 'pmap')))
|
||||
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
|
||||
*in_nodes_sharded,
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
@ -2852,7 +2852,7 @@ def lower_sharding_computation(
|
||||
the singleton _UNSPECIFIED to all out_avals.
|
||||
"""
|
||||
# 1. Trace to jaxpr and preprocess/verify it
|
||||
name_stack = new_name_stack(wrap_name(fun_name, api_name))
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
|
||||
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
||||
"in {elapsed_time} sec",
|
||||
@ -3068,7 +3068,7 @@ def lower_mesh_computation(
|
||||
in_is_global: Sequence[bool]) -> MeshComputation:
|
||||
assert not mesh.empty
|
||||
backend = xb.get_device_backend(mesh.devices.flat[0])
|
||||
name_stack = new_name_stack(wrap_name(fun_name, api_name))
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(fun_name, api_name))
|
||||
|
||||
auto_spmd_lowering = check_if_any_auto(in_shardings + out_shardings) # type: ignore
|
||||
|
||||
|
@ -37,8 +37,7 @@ from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
|
||||
partition_list)
|
||||
from jax._src.util import (prod, safe_zip, safe_map, partition_list)
|
||||
|
||||
from jax._src.typing import Shape
|
||||
|
||||
@ -251,7 +250,7 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
|
||||
rule = _translations[prim]
|
||||
|
||||
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
|
||||
name_stack=new_name_stack())
|
||||
name_stack=source_info_util.new_name_stack())
|
||||
ans = rule(ctx, avals_in, avals_out, *xla_args, **params)
|
||||
|
||||
if prim.multiple_results:
|
||||
|
@ -39,8 +39,7 @@ from jax._src import state
|
||||
from jax._src.core import ConcreteArray, raise_to_shaped, replace_jaxpr_effects
|
||||
from jax._src.lax import lax
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (safe_map, extend_name_stack, split_list,
|
||||
partition_list)
|
||||
from jax._src.util import (safe_map, split_list, partition_list)
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
import numpy as np
|
||||
@ -832,12 +831,12 @@ def _cond_lowering(ctx, index, *args, branches, linear):
|
||||
# captures.
|
||||
case_op = hlo.CaseOp(flat_output_types, index=index,
|
||||
num_branches=len(branches))
|
||||
name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
|
||||
name_stack = ctx.module_context.name_stack.extend('cond')
|
||||
for i, jaxpr in enumerate(branches):
|
||||
branch = case_op.regions[i].blocks.append()
|
||||
with ir.InsertionPoint(branch):
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(name_stack, f'branch_{i}_fun'))
|
||||
name_stack=name_stack.extend(f'branch_{i}_fun'))
|
||||
out_vals, tokens_out = mlir.jaxpr_subcomp(
|
||||
sub_ctx, jaxpr.jaxpr, tokens_in,
|
||||
map(mlir.ir_constants, jaxpr.consts),
|
||||
|
@ -46,7 +46,6 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.ufuncs import logaddexp
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.util import (
|
||||
extend_name_stack,
|
||||
partition_list,
|
||||
safe_map,
|
||||
safe_zip,
|
||||
@ -1480,7 +1479,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
|
||||
# Loop condition
|
||||
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
|
||||
name_stack = extend_name_stack(ctx.module_context.name_stack, 'while')
|
||||
name_stack = ctx.module_context.name_stack.extend('while')
|
||||
with ir.InsertionPoint(cond_block):
|
||||
flat_cond_args = [
|
||||
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
|
||||
@ -1489,8 +1488,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
# Remove tokens from cond args
|
||||
cond_args = cond_args[num_tokens:]
|
||||
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
|
||||
cond_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(name_stack, 'cond'))
|
||||
cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond'))
|
||||
((pred,),), _ = mlir.jaxpr_subcomp(cond_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
|
||||
_map(mlir.ir_constants, cond_jaxpr.consts),
|
||||
*(x + z), dim_var_values=ctx.dim_var_values)
|
||||
@ -1521,15 +1519,14 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
token_args, body_args = util.split_list(body_args, [num_tokens])
|
||||
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
|
||||
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
|
||||
body_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(name_stack, 'body'))
|
||||
body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body'))
|
||||
new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
|
||||
tokens_in, _map(mlir.ir_constants, body_jaxpr.consts),
|
||||
*(y + z), dim_var_values=ctx.dim_var_values)
|
||||
out_tokens = [tokens_out.get(eff) for eff in body_effects]
|
||||
if batched:
|
||||
body_pred_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(name_stack, 'body_pred'))
|
||||
name_stack=name_stack.extend('body_pred'))
|
||||
((body_pred,),), _ = mlir.jaxpr_subcomp(
|
||||
body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
|
||||
_map(mlir.ir_constants, cond_jaxpr.consts),
|
||||
|
@ -3638,7 +3638,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
|
||||
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
|
||||
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
|
||||
with ir.InsertionPoint(reducer):
|
||||
reducer_ctx = ctx.module_context.replace(name_stack=util.new_name_stack())
|
||||
reducer_ctx = ctx.module_context.replace(
|
||||
name_stack=source_info_util.new_name_stack())
|
||||
if jaxpr.effects:
|
||||
raise NotImplementedError('Cannot lower effectful `reduce`.')
|
||||
out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts,
|
||||
|
@ -25,6 +25,7 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
@ -2016,7 +2017,8 @@ def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
|
||||
update = op.update_computation.blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(update):
|
||||
update_ctx = ctx.module_context.replace(name_stack=util.new_name_stack())
|
||||
update_ctx = ctx.module_context.replace(
|
||||
name_stack=source_info_util.new_name_stack())
|
||||
if update_jaxpr.effects:
|
||||
raise NotImplementedError('Cannot lower effectful `scatter`.')
|
||||
out_nodes, _ = mlir.jaxpr_subcomp(
|
||||
|
@ -47,7 +47,7 @@ from jax.interpreters import ad
|
||||
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
|
||||
as_hashable_function, distributed_debug_log,
|
||||
tuple_insert, moveaxis, split_list, wrap_name,
|
||||
merge_lists, partition_list, extend_name_stack)
|
||||
merge_lists, partition_list)
|
||||
from jax import lax
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
@ -1374,8 +1374,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
@ -1442,8 +1441,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')))
|
||||
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in vectorized_jaxpr.effects):
|
||||
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
|
||||
@ -1494,8 +1492,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=extend_name_stack(ctx.module_context.name_stack,
|
||||
wrap_name(name, 'xmap')),
|
||||
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')),
|
||||
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
|
||||
if any(effects.ordered_effects.contains(eff) for eff
|
||||
in vectorized_jaxpr.effects):
|
||||
|
@ -102,6 +102,14 @@ class NameStack:
|
||||
scope = elem.wrap(scope)
|
||||
return '/'.join(scope)
|
||||
|
||||
|
||||
def new_name_stack(name: str = '') -> NameStack:
|
||||
name_stack = NameStack()
|
||||
if name:
|
||||
name_stack = name_stack.extend(name)
|
||||
return name_stack
|
||||
|
||||
|
||||
class SourceInfo(NamedTuple):
|
||||
traceback: Optional[Traceback]
|
||||
name_stack: NameStack
|
||||
|
@ -330,18 +330,6 @@ def get_module_functions(module):
|
||||
def wrap_name(name, transform_name):
|
||||
return transform_name + '(' + name + ')'
|
||||
|
||||
def new_name_stack(name: str = ''):
|
||||
from jax._src import source_info_util
|
||||
name_stack = source_info_util.NameStack()
|
||||
if name:
|
||||
name_stack = name_stack.extend(name)
|
||||
return name_stack
|
||||
|
||||
def extend_name_stack(stack, name: str):
|
||||
from jax._src import source_info_util
|
||||
assert isinstance(stack, source_info_util.NameStack), stack
|
||||
return stack.extend(name)
|
||||
|
||||
def canonicalize_axis(axis, num_dims) -> int:
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
axis = operator.index(axis)
|
||||
|
@ -95,7 +95,6 @@ from jax._src.interpreters.pxla import (
|
||||
mesh_sharding_specs as mesh_sharding_specs,
|
||||
multi_host_supported_collectives as multi_host_supported_collectives,
|
||||
new_mesh_sharding_specs as new_mesh_sharding_specs,
|
||||
new_name_stack as new_name_stack,
|
||||
op_sharding_to_indices as op_sharding_to_indices,
|
||||
parallel_callable as parallel_callable,
|
||||
partitioned_sharding_spec as partitioned_sharding_spec,
|
||||
|
Loading…
x
Reference in New Issue
Block a user