Don't enter JaxprEqnContext's manager for every equation in jaxpr_subcomp; only do it where it's needed.

For the threefry_partitionable flag, we only need to enter the context inside `mlir.lower_fun`.

PiperOrigin-RevId: 647416417
Yash Katariya, 2024-06-27 12:53:40 -07:00; committed by jax authors
parent 731763408b
commit 0230d0be3d
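
In effect, the change threads the per-equation `core.JaxprEqnContext` through to lowering rules and enters its `manager` only around the retrace in `mlir.lower_fun`, instead of around every equation. A minimal sketch of the pattern, using illustrative stand-in names rather than JAX's actual classes:

import contextlib

class EqnContext:
  """Illustrative stand-in for core.JaxprEqnContext."""

  def __init__(self, compute_type=None):
    self.compute_type = compute_type

  @property
  def manager(self):
    # In JAX this would scope per-equation config (e.g. the
    # threefry_partitionable flag); here it is a no-op placeholder.
    return contextlib.nullcontext()

def retrace_under_ctx(eqn_ctx, trace_fn):
  # Enter the equation's manager only around the retrace, as the new
  # lower_fun does; fall back to a null context when no ctx is attached.
  manager = contextlib.nullcontext() if eqn_ctx is None else eqn_ctx.manager
  with manager:
    return trace_fn()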

jax/_src/interpreters/mlir.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import collections
+import contextlib
 from collections.abc import Callable, Iterator, Sequence
 import dataclasses
 import functools
@@ -682,7 +683,7 @@ class LoweringRuleContext:
   # The values for the dimension variables in same order as
   # module_context.shape_poly_state.dim_vars
   dim_var_values: Sequence[ir.Value] = ()
-  compute_type: str | None = None
+  jaxpr_eqn_ctx: core.JaxprEqnContext | None = None
   # Override module_context.platforms if not None. Used during multi-platform
   # lowering, when in a scope with a subset of the module_context.platforms.
   platforms: Sequence[str] | None = None
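
With this field swap, a rule that previously read `ctx.compute_type` now goes through the stored equation context, guarding against `None` at both levels (the same check `wrap_compute_type_in_place` performs later in this diff). A sketch, with a hypothetical helper name:

def compute_type_of(rule_ctx):
  # rule_ctx is a LoweringRuleContext; jaxpr_eqn_ctx may be absent.
  eqn_ctx = rule_ctx.jaxpr_eqn_ctx
  return None if eqn_ctx is None else eqn_ctx.compute_type
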
@@ -1544,8 +1545,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
     source_info = eqn.source_info.replace(
         name_stack=name_stack + eqn.source_info.name_stack)
     loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
-    with (source_info_util.user_context(eqn.source_info.traceback), loc,
-          eqn.ctx.manager):
+    with source_info_util.user_context(eqn.source_info.traceback), loc:
       override_rule = get_override_lowering_rule(eqn.primitive)
       platform_rules: dict[str, LoweringRule] = {}
       default_rule: LoweringRule | None = None
@@ -1568,14 +1568,12 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
       tokens_in = tokens.subset(effects)
       avals_in = map(aval, eqn.invars)
-      compute_type = eqn.ctx.compute_type if eqn.ctx is not None else None
       rule_ctx = LoweringRuleContext(
           module_context=ctx, primitive=eqn.primitive,
           name_stack=source_info.name_stack,
           avals_in=avals_in,
           avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
-          tokens_out=None, dim_var_values=dim_var_values,
-          compute_type=compute_type)
+          tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values)
       if config.dynamic_shapes.value:
         axis_size_env = {d: read(d)[0]
                          for a in avals_in if type(a) is core.DShapedArray
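
The point of these two hunks is that `jaxpr_subcomp`'s loop runs once per equation, so the old `with eqn.ctx.manager:` presumably paid a context-manager entry and exit for every lowered op even when nothing inside needed it; now `eqn.ctx` is simply stored on the rule context. Roughly, as an illustration of the shape of the loop only:

# Before: enter the manager around every equation's lowering.
#   for eqn in jaxpr.eqns:
#     with eqn.ctx.manager:
#       invoke_rule(eqn)
# After: pass the context along; only lower_fun enters its manager,
# and only when it actually retraces a Python function.
def lower_eqns(eqns, invoke_rule):
  for eqn in eqns:
    invoke_rule(eqn, jaxpr_eqn_ctx=eqn.ctx)  # no manager entry here
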
@@ -1767,37 +1765,40 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
   def f_lowered(ctx: LoweringRuleContext, *args, **params):
     f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
     wrapped_fun = lu.wrap_init(f, params)
+    manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else
+               ctx.jaxpr_eqn_ctx.manager)

-    if config.dynamic_shapes.value:
-      # We might be applying this function to arguments with dynamic shapes,
-      # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
-      # case, we need to form a jaxpr with leading binders for those axis size
-      # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
-      # and we need to call jaxpr_subcomp with these arguments made explicit.
-      assert ctx.axis_size_env is not None
-      args = (*ctx.axis_size_env.values(), *args)
-      idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
-      i32_aval = core.ShapedArray((), np.dtype('int32'))
-      implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
-      explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))  # type: ignore
-                        if type(a) is core.DShapedArray else a, True)
-                       for a in ctx.avals_in]
-      wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
-      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
-    else:
-      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
-      # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
+    with manager:
+      if config.dynamic_shapes.value:
+        # We might be applying this function to arguments with dynamic shapes,
+        # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
+        # case, we need to form a jaxpr with leading binders for those axis size
+        # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
+        # and we need to call jaxpr_subcomp with these arguments made explicit.
+        assert ctx.axis_size_env is not None
+        args = (*ctx.axis_size_env.values(), *args)
+        idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
+        i32_aval = core.ShapedArray((), np.dtype('int32'))
+        implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
+        explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))  # type: ignore
+                          if type(a) is core.DShapedArray else a, True)
+                         for a in ctx.avals_in]
+        wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
+        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
+      else:
+        jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+        # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?

-    if ctx.platforms is not None:
-      sub_context = ctx.module_context.replace(platforms=ctx.platforms)
-    else:
-      sub_context = ctx.module_context
-    out, tokens = jaxpr_subcomp(
-        sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
-        _ir_consts(consts), *map(wrap_singleton_ir_values, args),
-        dim_var_values=ctx.dim_var_values)
-    ctx.set_tokens_out(tokens)
-    return out
+      if ctx.platforms is not None:
+        sub_context = ctx.module_context.replace(platforms=ctx.platforms)
+      else:
+        sub_context = ctx.module_context
+      out, tokens = jaxpr_subcomp(
+          sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
+          _ir_consts(consts), *map(wrap_singleton_ir_values, args),
+          dim_var_values=ctx.dim_var_values)
+      ctx.set_tokens_out(tokens)
+      return out

   return f_lowered
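
Note that the tracing calls (`pe.trace_to_jaxpr_dynamic` / `trace_to_jaxpr_dynamic2`) sit inside `with manager:`. A flag like `threefry_partitionable` can change which primitives get emitted while retracing, so the equation's config must be active during the retrace itself, not just during `jaxpr_subcomp`. A runnable toy version of that scoping, with a hypothetical flag dict standing in for JAX's config state:

import contextlib

FLAGS = {"threefry_partitionable": False}  # hypothetical config store

@contextlib.contextmanager
def scoped_flag(name, value):
  # Scope a flag around a block, the way JaxprEqnContext.manager scopes
  # per-equation config around the retrace in lower_fun.
  old = FLAGS[name]
  FLAGS[name] = value
  try:
    yield
  finally:
    FLAGS[name] = old

def trace():  # stands in for pe.trace_to_jaxpr_dynamic
  return FLAGS["threefry_partitionable"]

with scoped_flag("threefry_partitionable", True):
  assert trace() is True   # flag visible while "tracing"
assert trace() is False    # restored afterwards
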
@@ -1886,9 +1887,9 @@ def map_compute_type(c_type):
                    'are `device_host` and `device`')


 def wrap_compute_type_in_place(ctx, op):
-  if ctx.compute_type is not None:
+  if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None:
     dict_attr = {"_xla_compute_type": ir.StringAttr.get(
-        map_compute_type(ctx.compute_type))}
+        map_compute_type(ctx.jaxpr_eqn_ctx.compute_type))}
     op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
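
For reference, the helper's effect is to tag the lowered op with an `mhlo.frontend_attributes` dictionary attribute that XLA reads. A sketch of the same attachment in isolation, assuming an active `ir.Context` (as there is during JAX lowering) and a hypothetical helper name:

from jax._src.lib.mlir import ir  # JAX's bundled MLIR Python bindings

def attach_compute_type(op, xla_compute_type: str):
  # Must run under an active ir.Context; mirrors the body of
  # wrap_compute_type_in_place above.
  dict_attr = {"_xla_compute_type": ir.StringAttr.get(xla_compute_type)}
  op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)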