Split name_stack out of mlir.ModuleContext.

A unique name_stack is built for every equation, which means we are constantly rebuilding ModuleContext objects, even though almost everything else in the context naturally has Module-scoped lifetime. Split name_stack into a value that is threaded separately, including as a field of mlir.LoweringRuleContext.

PiperOrigin-RevId: 608594374

commit f1ea67117e
parent 2165611584
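
The core of the change can be sketched as follows: a minimal, runnable illustration of the pattern, where NameStack, CtxBefore, and CtxAfter are hypothetical stand-ins rather than JAX's real ModuleContext/LoweringRuleContext.

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class NameStack:
    names: tuple = ()
    def extend(self, name):
        return NameStack(self.names + (name,))

@dataclass(frozen=True)
class CtxBefore:   # name_stack lives on the module-scoped context
    platform: str
    name_stack: NameStack

@dataclass(frozen=True)
class CtxAfter:    # context now carries only module-lifetime state
    platform: str

def lower_before(ctx: CtxBefore, eqns):
    for eqn in eqns:
        # one fresh context object per equation, just to swap one field
        eqn_ctx = replace(ctx, name_stack=ctx.name_stack.extend(eqn))
        yield eqn_ctx.name_stack.names

def lower_after(ctx: CtxAfter, name_stack: NameStack, eqns):
    for eqn in eqns:
        # the long-lived ctx is shared; only the small stack value is rebuilt
        yield name_stack.extend(eqn).names

print(list(lower_before(CtxBefore('cpu', NameStack()), ['add', 'mul'])))
print(list(lower_after(CtxAfter('cpu'), NameStack(), ['add', 'mul'])))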
@@ -437,8 +437,8 @@ def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
   args_ = map(mlir.wrap_singleton_ir_values, args)
   consts = mlir._ir_consts(call_jaxpr.consts)
   out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
-                                   ctx.tokens_in, consts, *args_,
-                                   dim_var_values=ctx.dim_var_values)
+                                   ctx.name_stack, ctx.tokens_in, consts,
+                                   *args_, dim_var_values=ctx.dim_var_values)
   ctx.set_tokens_out(tokens)
   return out
 mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)

@@ -340,24 +340,23 @@ register_constant_handler(core.Token, _token_constant_handler)
 # Source locations

 def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
-  if file_name in caches.canonical_name_cache:
-    return caches.canonical_name_cache[file_name]
+  canonical_file_name = caches.canonical_name_cache.get(file_name, None)
+  if canonical_file_name is not None:
+    return canonical_file_name

-  source_file = file_name
   pattern = config.hlo_source_file_canonicalization_regex.value
   if pattern:
-    source_file = re.sub(pattern, '', source_file)
-
-  caches.canonical_name_cache[file_name] = source_file
-  return source_file
+    file_name = re.sub(pattern, '', file_name)
+  caches.canonical_name_cache[file_name] = file_name
+  return file_name

 def _is_user_file(ctx: ModuleContext, file_name: str) -> bool:
-  if file_name in ctx.traceback_caches.is_user_file_cache:
-    return ctx.traceback_caches.is_user_file_cache[file_name]
-
-  result = source_info_util.is_user_filename(file_name)
-  ctx.traceback_caches.is_user_file_cache[file_name] = result
-  return result
+  is_user = ctx.traceback_caches.is_user_file_cache.get(file_name, None)
+  if is_user is not None:
+    return is_user
+  out = source_info_util.is_user_filename(file_name)
+  ctx.traceback_caches.is_user_file_cache[file_name] = out
+  return out

 def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
   """Converts a full traceback to a callsite() MLIR location."""

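Aside from the name_stack threading, this hunk also tightens the cache lookups: a membership test plus indexing (two hash lookups) becomes a single dict.get with a None sentinel. A standalone sketch of the idiom, with a hypothetical _cache and _PATTERN:

import re

# Hypothetical cache and pattern, for illustration only.
_cache: dict = {}
_PATTERN = r'^/build/workdir/'

def canonical_name(file_name: str) -> str:
    # Single hash lookup via dict.get, instead of `in` followed by indexing.
    cached = _cache.get(file_name, None)
    if cached is not None:
        return cached
    canonical = re.sub(_PATTERN, '', file_name)
    _cache[file_name] = canonical
    return canonical

print(canonical_name('/build/workdir/jax/example.py'))  # -> jax/example.py
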
@@ -386,12 +385,12 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location:
     if len(frame_locs) >= frames_limit:
       break

-  if len(frame_locs) == 0:
+  n = len(frame_locs)
+  if n == 0:
     return ir.Location.unknown()
-
-  if len(frame_locs) == 1:
-    return frame_locs[0]
-
-  return ir.Location.callsite(frame_locs[0], frame_locs[1:])
+  elif n == 1:
+    return frame_locs[0]
+  else:
+    return ir.Location.callsite(frame_locs[0], frame_locs[1:])

 def _source_info_to_location(

@@ -589,7 +588,6 @@ class ModuleContext:
   backend_or_name: str | xb.XlaBackend | None
   platforms: Sequence[str]
   axis_context: AxisContext
-  name_stack: source_info_util.NameStack
   keepalives: list[Any]
   channel_iterator: Iterator[int]
   host_callbacks: list[Any]

@@ -614,7 +612,6 @@ class ModuleContext:
       backend_or_name: str | xb.XlaBackend | None,
       platforms: Sequence[str],
       axis_context: AxisContext,
-      name_stack: source_info_util.NameStack,
       keepalives: list[Any],
       channel_iterator: Iterator[int],
       host_callbacks: list[Any],

@@ -635,7 +632,6 @@ class ModuleContext:
     self.backend_or_name = backend_or_name
     self.platforms = platforms
     self.axis_context = axis_context
-    self.name_stack = name_stack
     self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
                                        else cached_primitive_lowerings)
     self.traceback_caches = (TracebackCaches() if traceback_caches is None

@@ -683,6 +679,7 @@ class ModuleContext:
 class LoweringRuleContext:
   """Per-rule context information for MLIR lowering."""
   module_context: ModuleContext
+  name_stack: source_info_util.NameStack
   primitive: core.Primitive | None
   avals_in: Sequence[core.AbstractValue]
   avals_out: Any  # Usually Sequence[core.AbstractValue], but sometimes None.

@@ -947,7 +944,6 @@ def lower_jaxpr_to_module(

   ctx = ModuleContext(backend_or_name=backend_or_name,
                       platforms=platforms, axis_context=axis_context,
-                      name_stack=name_stack,
                       keepalives=keepalives,
                       channel_iterator=channel_iter,
                       host_callbacks=host_callbacks,

@@ -964,7 +960,9 @@ def lower_jaxpr_to_module(
     attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
   replace_tokens_with_dummy = lowering_parameters.replace_tokens_with_dummy
   lower_jaxpr_to_fun(
-      ctx, "main", jaxpr, ordered_effects, public=True,
+      ctx, "main", jaxpr, ordered_effects,
+      name_stack=name_stack,
+      public=True,
       create_tokens=replace_tokens_with_dummy,
       replace_tokens_with_dummy=replace_tokens_with_dummy,
       num_output_tokens=0,

@@ -1105,6 +1103,7 @@ def lower_jaxpr_to_fun(
     name: str,
     jaxpr: core.ClosedJaxpr,
     effects: Sequence[core.Effect],
+    name_stack: source_info_util.NameStack,
     *,
     create_tokens: bool = False,
     public: bool = False,

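Since name_stack is added before the * marker, call sites may pass it positionally or by keyword; both forms appear in the hunks above and below. A toy illustration, where f is hypothetical and merely stands in for the new lower_jaxpr_to_fun signature:

def f(ctx, name, jaxpr, effects, name_stack, *, public=False):
    return name_stack, public

f('ctx', 'main', 'jaxpr', [], 'stack')             # positional, as in the pjit hunk
f('ctx', 'main', 'jaxpr', [], name_stack='stack')  # keyword, as in the "main" lowering
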
@@ -1376,7 +1375,7 @@ def lower_jaxpr_to_fun(
   dim_var_values, _, _ = util.split_list(flat_args, [num_dim_vars, num_tokens])
   # A lowering context just for function body entry/exit code.
   entry_lowering_ctx = LoweringRuleContext(
-      module_context=ctx, primitive=None,
+      module_context=ctx, name_stack=name_stack, primitive=None,
       avals_in=[], avals_out=None,
       tokens_in=TokenSet.create([]), tokens_out=None,
       axis_size_env=None, dim_var_values=dim_var_values)

@@ -1403,10 +1402,10 @@ def lower_jaxpr_to_fun(
         args.append([hlo.create_token()])
       else:
         args.append(arg)
-    callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
+    callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
     consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
     out_vals, tokens_out = jaxpr_subcomp(
-        ctx.replace(name_stack=callee_name_stack), jaxpr.jaxpr, tokens_in,
+        ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
         consts, *args, dim_var_values=dim_var_values)
     outs = []
     if create_tokens:

@@ -1496,6 +1495,7 @@ def _emit_lowering_rule_as_fun(lowering_rule,
   return func_op

 def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
+                  name_stack: source_info_util.NameStack,
                   tokens: TokenSet,
                   consts: Sequence[Sequence[ir.Value]],
                   *args: Sequence[ir.Value],

@@ -1536,6 +1536,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,

   env: dict[core.Var, tuple[ir.Value, ...]] = {}

+  assert isinstance(name_stack, source_info_util.NameStack), type(name_stack)
   assert len(args) == len(jaxpr.invars), (jaxpr, args)
   assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
   assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts

@@ -1545,9 +1546,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   last_used = core.last_used(jaxpr)
   for eqn in jaxpr.eqns:
     in_nodes = map(read, eqn.invars)
-    assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
     source_info = eqn.source_info.replace(
-        name_stack=ctx.name_stack + eqn.source_info.name_stack)
+        name_stack=name_stack + eqn.source_info.name_stack)
     loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
     with source_info_util.user_context(eqn.source_info.traceback), loc:
       override_rule = get_override_lowering_rule(eqn.primitive)

@@ -1569,12 +1569,12 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       elif eqn.primitive in xla._translations:
         default_rule = xla_fallback_lowering(eqn.primitive)

-      eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
       effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
       tokens_in = tokens.subset(effects)
       avals_in = map(aval, eqn.invars)
       rule_ctx = LoweringRuleContext(
-          module_context=eqn_ctx, primitive=eqn.primitive,
+          module_context=ctx, primitive=eqn.primitive,
+          name_stack=source_info.name_stack,
           avals_in=avals_in,
           avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
           tokens_out=None, dim_var_values=dim_var_values)

@@ -1781,15 +1781,16 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
     # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?

     out, tokens = jaxpr_subcomp(
-        ctx.module_context, jaxpr, ctx.tokens_in, _ir_consts(consts),
-        *map(wrap_singleton_ir_values, args), dim_var_values=ctx.dim_var_values)
+        ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
+        _ir_consts(consts), *map(wrap_singleton_ir_values, args),
+        dim_var_values=ctx.dim_var_values)
     ctx.set_tokens_out(tokens)
     return out

   return f_lowered


-def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
+def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, name_stack,
                                arg_names=None, result_names=None):
   if not call_jaxpr.consts and arg_names is result_names is None:
     # Cacheable.

@@ -1798,12 +1799,12 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
       func_op = ctx.cached_primitive_lowerings[key]
     except KeyError:
       func_op = lower_jaxpr_to_fun(
-          ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
+          ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
           result_names=result_names)
       ctx.cached_primitive_lowerings[key] = func_op
   else:
     func_op = lower_jaxpr_to_fun(
-        ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
+        ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
         result_names=result_names)
   return func_op

@@ -1825,12 +1826,12 @@ def check_backend_matches(inner_backend: str | None,
                           f"inner-jit backend specification {inner_backend}.")


-def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
-                   ctx: ModuleContext, avals_in,
-                   avals_out, tokens_in, *args,
-                   dim_var_values: Sequence[ir.Value],
-                   arg_names=None, result_names=None):
-  del stack_name, avals_in
+def call_lowering(fn_name, name_stack, call_jaxpr, backend,
+                  ctx: ModuleContext, avals_in,
+                  avals_out, tokens_in, *args,
+                  dim_var_values: Sequence[ir.Value],
+                  arg_names=None, result_names=None):
+  del avals_in
   if isinstance(call_jaxpr, core.Jaxpr):
     call_jaxpr = pe.close_jaxpr(call_jaxpr)
   check_backend_matches(backend, ctx.platforms)

@@ -1839,7 +1840,7 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend,
   output_types = [token_type()] * len(effects) + output_types
   flat_output_types = util.flatten(output_types)
   symbol_name = _lower_jaxpr_to_fun_cached(
-      ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
+      ctx, fn_name, call_jaxpr, effects, name_stack, arg_names=arg_names,
       result_names=result_names).name.value
   tokens = [tokens_in.get(eff) for eff in effects]
   args = (*dim_var_values, *tokens, *args)

@@ -1853,8 +1854,8 @@ def _call_lowering(fn_name, stack_name, call_jaxpr, backend,

 def core_call_lowering(ctx: LoweringRuleContext,
                        *args, name, backend=None, call_jaxpr):
-  out_nodes, tokens = _call_lowering(
-      name, name, call_jaxpr, backend, ctx.module_context,
+  out_nodes, tokens = call_lowering(
+      name, ctx.name_stack, call_jaxpr, backend, ctx.module_context,
       ctx.avals_in, ctx.avals_out, ctx.tokens_in, *args,
       dim_var_values=ctx.dim_var_values)
   ctx.set_tokens_out(tokens)

@@ -1433,12 +1433,12 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,

   with maybe_extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
     sub_ctx = ctx.module_context.replace(
-        axis_context=sharding_impls.ReplicaAxisContext(new_env),
-        name_stack=ctx.module_context.name_stack.extend(
-            util.wrap_name(name, 'pmap')))
-    sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
-                                         *in_nodes_sharded,
-                                         dim_var_values=ctx.dim_var_values)
+        axis_context=sharding_impls.ReplicaAxisContext(new_env))
+    sharded_outs, _ = mlir.jaxpr_subcomp(
+        sub_ctx, call_jaxpr,
+        ctx.name_stack.extend(util.wrap_name(name, 'pmap')),
+        mlir.TokenSet(), (), *in_nodes_sharded,
+        dim_var_values=ctx.dim_var_values)
   out_avals = [v.aval for v in call_jaxpr.outvars]
   outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard)
           for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]

@@ -847,16 +847,14 @@ def _cond_lowering(ctx, index, *args, branches, linear):
   # captures.
   case_op = hlo.CaseOp(flat_output_types, index=index,
                        num_branches=len(branches))
-  name_stack = ctx.module_context.name_stack.extend('cond')
+  name_stack = ctx.name_stack.extend('cond')
   for i, jaxpr in enumerate(branches):
     branch = case_op.regions[i].blocks.append()
     with ir.InsertionPoint(branch):
-      sub_ctx = ctx.module_context.replace(
-          name_stack=name_stack.extend(f'branch_{i}_fun'))
       consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
       out_vals, tokens_out = mlir.jaxpr_subcomp(
-          sub_ctx, jaxpr.jaxpr, tokens_in,
-          consts, *map(mlir.wrap_singleton_ir_values, args),
+          ctx.module_context, jaxpr.jaxpr, name_stack.extend(f'branch_{i}_fun'),
+          tokens_in, consts, *map(mlir.wrap_singleton_ir_values, args),
           dim_var_values=ctx.dim_var_values)
       out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
       out_vals = [*out_tokens, *out_vals]

@@ -1662,7 +1662,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,

   # Loop condition
   cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
-  name_stack = ctx.module_context.name_stack.extend('while')
+  name_stack = ctx.name_stack.extend('while')
   with ir.InsertionPoint(cond_block):
     flat_cond_args = [
         cond_block.arguments[i] for i in range(len(flat_loop_carry_types))

@@ -1671,13 +1671,14 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
     # Remove tokens from cond args
     cond_args = cond_args[num_tokens:]
     x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
-    cond_ctx = ctx.module_context.replace(name_stack=name_stack.extend('cond'))
     cond_consts = [
         mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
     ]
+    cond_name_stack = name_stack.extend('cond')
     ((pred,),), _ = mlir.jaxpr_subcomp(
-        cond_ctx,
+        ctx.module_context,
         cond_jaxpr.jaxpr,
+        cond_name_stack,
         mlir.TokenSet(),
         cond_consts,
         *(x + z),

@@ -1686,6 +1687,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
     if batched:
       pred_ctx = mlir.LoweringRuleContext(
           module_context=ctx.module_context,
+          name_stack=cond_name_stack,
          primitive=None,
          avals_in=[pred_aval],
          avals_out=[pred_aval.update(shape=())],

@@ -1710,20 +1712,21 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
     token_args, body_args = util.split_list(body_args, [num_tokens])
     tokens_in = mlir.TokenSet(zip(body_effects, token_args))
     x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
-    body_ctx = ctx.module_context.replace(name_stack=name_stack.extend('body'))
+    body_name_stack = name_stack.extend('body')
     body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
                    for x in body_jaxpr.consts]
-    new_z, tokens_out = mlir.jaxpr_subcomp(body_ctx, body_jaxpr.jaxpr,
+    new_z, tokens_out = mlir.jaxpr_subcomp(
+        ctx.module_context, body_jaxpr.jaxpr, body_name_stack,
         tokens_in, body_consts, *(y + z), dim_var_values=ctx.dim_var_values)
     out_tokens = [tokens_out.get(eff) for eff in body_effects]
     if batched:
-      body_pred_ctx = ctx.module_context.replace(
-          name_stack=name_stack.extend('body_pred'))
+      body_pred_name_stack = name_stack.extend('body_pred')
       cond_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
                      for x in cond_jaxpr.consts]
       ((body_pred,),), _ = mlir.jaxpr_subcomp(
-          body_pred_ctx, cond_jaxpr.jaxpr, mlir.TokenSet(),
-          cond_consts, *(x + z), dim_var_values=ctx.dim_var_values)
+          ctx.module_context, cond_jaxpr.jaxpr, body_pred_name_stack,
+          mlir.TokenSet(), cond_consts, *(x + z),
+          dim_var_values=ctx.dim_var_values)
       new_z = _map(
           partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
           body_jaxpr.out_avals)

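The while-loop hunks above show how nested scopes compose: the loop extends the caller's stack with 'while', and the sub-computations extend that with 'cond', 'body', and 'body_pred'. A toy NameStack (a stand-in, not JAX's source_info_util.NameStack) makes the composition concrete:

# Toy stand-in for source_info_util.NameStack, illustration only.
class NameStack:
    def __init__(self, names=()):
        self.names = tuple(names)
    def extend(self, name):
        return NameStack(self.names + (name,))
    def __str__(self):
        return '/'.join(self.names)

name_stack = NameStack().extend('while')   # scope for the whole loop
print(name_stack.extend('cond'))           # -> while/cond
print(name_stack.extend('body'))           # -> while/body
print(name_stack.extend('body_pred'))      # -> while/body_pred
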
@@ -3813,11 +3813,11 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
   ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
   reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
   with ir.InsertionPoint(reducer):
-    reducer_ctx = ctx.module_context.replace(
-        name_stack=source_info_util.new_name_stack())
+    name_stack = source_info_util.new_name_stack()
     if jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `reduce`.')
-    out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr.jaxpr, mlir.TokenSet(),
+    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
+                                      name_stack, mlir.TokenSet(),
                                       jaxpr.consts,
                                       *([a] for a in reducer.arguments),
                                       dim_var_values=ctx.dim_var_values)

@@ -2493,13 +2493,12 @@ def _scatter_lower(ctx, operand, indices, updates, *,
     scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), aval_out.dtype))
     update = op.update_computation.blocks.append(scalar_type, scalar_type)
     with ir.InsertionPoint(update):
-      update_ctx = ctx.module_context.replace(
-          name_stack=source_info_util.new_name_stack())
+      name_stack = source_info_util.new_name_stack()
       if update_jaxpr.effects:
         raise NotImplementedError('Cannot lower effectful `scatter`.')
       out_nodes, _ = mlir.jaxpr_subcomp(
-          update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
-          (update.arguments[0],), (update.arguments[1],),
+          ctx.module_context, update_jaxpr, name_stack, mlir.TokenSet(),
+          update_consts, (update.arguments[0],), (update.arguments[1],),
          dim_var_values=ctx.dim_var_values)
       hlo.return_(util.flatten(out_nodes))
   return op.results

@@ -320,7 +320,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
   def reducer_body(reducer: ir.Block) -> Sequence[ir.Value]:
     if jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `reduce_window`.')
-    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
+    out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, ctx.name_stack,
        mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
        dim_var_values=ctx.dim_var_values)
     return util.flatten(out_nodes)

@@ -529,6 +529,7 @@ def _select_and_scatter_lower(
     if select_jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `select`.')
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
+                                      ctx.name_stack,
                                       mlir.TokenSet(), select_consts,
                                       *([a] for a in select.arguments),
                                       dim_var_values=ctx.dim_var_values)

@@ -538,6 +539,7 @@ def _select_and_scatter_lower(
     if scatter_jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `scatter`.')
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
+                                      ctx.name_stack,
                                       mlir.TokenSet(), scatter_consts,
                                       *([a] for a in scatter.arguments),
                                       dim_var_values=ctx.dim_var_values)

@@ -1360,12 +1360,12 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
   # them!
   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
-  sub_ctx = ctx.module_context.replace(
-      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
+  name_stack = ctx.name_stack.extend(wrap_name(name, 'xmap'))
   if any(effects.ordered_effects.contains(eff) for eff
          in vectorized_jaxpr.effects):
     raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
-  tiled_outs, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr, mlir.TokenSet(),
+  tiled_outs, _ = mlir.jaxpr_subcomp(ctx.module_context, vectorized_jaxpr,
+                                     name_stack, mlir.TokenSet(),
                                      const_nodes, *tiled_ins,
                                      dim_var_values=ctx.dim_var_values)

@@ -1429,14 +1429,13 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,

   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.
-  sub_ctx = ctx.module_context.replace(
-      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')))
+  name_stack = ctx.name_stack.extend(wrap_name(name, 'xmap'))
   if any(effects.ordered_effects.contains(eff) for eff
          in vectorized_jaxpr.effects):
     raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
-  global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
-      mlir.TokenSet(), const_nodes, *sharded_global_in_nodes,
-      dim_var_values=ctx.dim_var_values)
+  global_out_nodes, _ = mlir.jaxpr_subcomp(
+      ctx.module_context, vectorized_jaxpr, name_stack, mlir.TokenSet(),
+      const_nodes, *sharded_global_in_nodes, dim_var_values=ctx.dim_var_values)

   sharded_global_out_nodes = [
       mlir.wrap_with_sharding_op(

@@ -1484,13 +1483,14 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
   # translation rule just because the extra tuple stuff is a pain.
   assert isinstance(ctx.module_context.axis_context,
                     sharding_impls.SPMDAxisContext)
+  name_stack = ctx.name_stack.extend(wrap_name(name, 'xmap'))
   sub_ctx = ctx.module_context.replace(
-      name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')),
       axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
   if any(effects.ordered_effects.contains(eff) for eff
          in vectorized_jaxpr.effects):
     raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
-  global_out_nodes, _ = mlir.jaxpr_subcomp(sub_ctx, vectorized_jaxpr,
+  global_out_nodes, _ = mlir.jaxpr_subcomp(
+      sub_ctx, vectorized_jaxpr, name_stack,
      mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes),
      dim_var_values=ctx.dim_var_values)

@@ -1602,9 +1602,9 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
     # inputs or outputs because they are lost during MLIR->HLO conversion.
     # using_sharding_annotation=False means we add an identity operation instead.
     func = mlir.lower_jaxpr_to_fun(
-        mod_ctx, name, jaxpr, effects, arg_shardings=arg_shardings,
-        result_shardings=result_shardings, use_sharding_annotations=False,
-        api_name=api_name)
+        mod_ctx, name, jaxpr, effects, ctx.name_stack,
+        arg_shardings=arg_shardings, result_shardings=result_shardings,
+        use_sharding_annotations=False, api_name=api_name)
     mod_ctx.cached_primitive_lowerings[key] = func
   return func

@@ -676,14 +676,14 @@ def _wrap_main_func(
   module_context = mlir.ModuleContext(
       backend_or_name="cpu", platforms=["cpu"],
       axis_context=sharding_impls.ShardingContext(0),
-      name_stack=source_info_util.new_name_stack(),
       keepalives=[], channel_iterator=itertools.count(1),
       host_callbacks=[], module=wrapped_module, context=context,
       lowering_parameters=mlir.LoweringParameters(
           global_constant_computation=True
       ))
   ctx = mlir.LoweringRuleContext(
-      module_context=module_context, primitive=None,
+      module_context=module_context,
+      name_stack=source_info_util.new_name_stack(), primitive=None,
       avals_in=args_avals_flat, avals_out=None,
       tokens_in=mlir.TokenSet(), tokens_out=None)
   # We compute dim_values from the array arguments.

@@ -548,9 +548,9 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
   )
   sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
   with core.extend_axis_env_nd(tuple(mesh.shape.items())):
-    out_nodes_, tokens_out = mlir._call_lowering(
-        "shmap_body", (), jaxpr, None, sub_ctx, in_avals_, out_avals_,
-        ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values,
+    out_nodes_, tokens_out = mlir.call_lowering(
+        "shmap_body", ctx.name_stack, jaxpr, None, sub_ctx, in_avals_,
+        out_avals_, ctx.tokens_in, *in_nodes_, dim_var_values=ctx.dim_var_values,
        arg_names=map(_pspec_mhlo_attrs, in_names, in_avals_),
        result_names=map(_pspec_mhlo_attrs, out_names, out_avals_))
   ctx.set_tokens_out(tokens_out)

@@ -27,7 +27,7 @@ from jax._src.interpreters.mlir import (
     Token as Token,
     TokenSet as TokenSet,
     Value as Value,
-    _call_lowering as _call_lowering,
+    call_lowering as _call_lowering,
     _lowerings as _lowerings,
     _platform_specific_lowerings as _platform_specific_lowerings,
     aval_to_ir_type as aval_to_ir_type,