Don't enter into JaxprEqnContext's manager on every equation in jaxpr_subcomp; only do it where it's needed.

For the threefry_partitionable flag, we only need to enter into the contexts inside `mlir.lower_fun`.

PiperOrigin-RevId: 647416417
This commit is contained in:
parent 731763408b
commit 0230d0be3d
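For context, `core.JaxprEqnContext` records per-equation settings such as `compute_type` and the `threefry_partitionable` flag, and its `manager` applies them for the duration of a block. Below is a rough, hypothetical model of that shape (not JAX's actual class; the `_config` dict stands in for jax.config state) to show why entering it around every equation has a cost:

```python
import contextlib
from dataclasses import dataclass

# Toy stand-in for jax.config state; names here are illustrative only.
_config = {"threefry_partitionable": False}

@dataclass
class EqnContext:
    """Rough, hypothetical model of core.JaxprEqnContext."""
    compute_type: str | None = None
    threefry_partitionable: bool | None = None

    @property
    def manager(self):
        # Apply this equation's settings for the duration of the block,
        # restoring the previous values afterwards.
        @contextlib.contextmanager
        def _apply():
            old = _config["threefry_partitionable"]
            if self.threefry_partitionable is not None:
                _config["threefry_partitionable"] = self.threefry_partitionable
            try:
                yield
            finally:
                _config["threefry_partitionable"] = old
        return _apply()
```

Entering a manager like this once per equation in `jaxpr_subcomp` pays a context push/pop on every equation even when nothing changes; the commit defers it to `mlir.lower_fun`, the one place it is needed for `threefry_partitionable`.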
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import collections
+import contextlib
 from collections.abc import Callable, Iterator, Sequence
 import dataclasses
 import functools
@@ -682,7 +683,7 @@ class LoweringRuleContext:
   # The values for the dimension variables in same order as
   # module_context.shape_poly_state.dim_vars
   dim_var_values: Sequence[ir.Value] = ()
-  compute_type: str | None = None
+  jaxpr_eqn_ctx: core.JaxprEqnContext | None = None
   # Override module_context.platforms if not None. Used during multi-platform
   # lowering, when in a scope with a subset of the module_context.platforms.
   platforms: Sequence[str] | None = None
@@ -1544,8 +1545,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
     source_info = eqn.source_info.replace(
         name_stack=name_stack + eqn.source_info.name_stack)
     loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info)
-    with (source_info_util.user_context(eqn.source_info.traceback), loc,
-          eqn.ctx.manager):
+    with source_info_util.user_context(eqn.source_info.traceback), loc:
       override_rule = get_override_lowering_rule(eqn.primitive)
       platform_rules: dict[str, LoweringRule] = {}
       default_rule: LoweringRule | None = None
@@ -1568,14 +1568,12 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
       tokens_in = tokens.subset(effects)
       avals_in = map(aval, eqn.invars)
-      compute_type = eqn.ctx.compute_type if eqn.ctx is not None else None
       rule_ctx = LoweringRuleContext(
           module_context=ctx, primitive=eqn.primitive,
           name_stack=source_info.name_stack,
           avals_in=avals_in,
           avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
-          tokens_out=None, dim_var_values=dim_var_values,
-          compute_type=compute_type)
+          tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values)
       if config.dynamic_shapes.value:
         axis_size_env = {d: read(d)[0]
                          for a in avals_in if type(a) is core.DShapedArray
@@ -1767,37 +1765,40 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
   def f_lowered(ctx: LoweringRuleContext, *args, **params):
     f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
     wrapped_fun = lu.wrap_init(f, params)
+    manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else
+               ctx.jaxpr_eqn_ctx.manager)
 
-    if config.dynamic_shapes.value:
-      # We might be applying this function to arguments with dynamic shapes,
-      # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
-      # case, we need to form a jaxpr with leading binders for those axis size
-      # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
-      # and we need to call jaxpr_subcomp with these arguments made explicit.
-      assert ctx.axis_size_env is not None
-      args = (*ctx.axis_size_env.values(), *args)
-      idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
-      i32_aval = core.ShapedArray((), np.dtype('int32'))
-      implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
-      explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))  # type: ignore
-                        if type(a) is core.DShapedArray else a, True)
-                       for a in ctx.avals_in]
-      wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
-      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
-    else:
-      jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
-      # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
+    with manager:
+      if config.dynamic_shapes.value:
+        # We might be applying this function to arguments with dynamic shapes,
+        # i.e. there might be Vars in the shape tuples of ctx.avals_in. In that
+        # case, we need to form a jaxpr with leading binders for those axis size
+        # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
+        # and we need to call jaxpr_subcomp with these arguments made explicit.
+        assert ctx.axis_size_env is not None
+        args = (*ctx.axis_size_env.values(), *args)
+        idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
+        i32_aval = core.ShapedArray((), np.dtype('int32'))
+        implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
+        explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))  # type: ignore
+                          if type(a) is core.DShapedArray else a, True)
+                         for a in ctx.avals_in]
+        wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
+        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic2(wrapped_fun)
+      else:
+        jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
+        # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
 
-    if ctx.platforms is not None:
-      sub_context = ctx.module_context.replace(platforms=ctx.platforms)
-    else:
-      sub_context = ctx.module_context
-    out, tokens = jaxpr_subcomp(
-        sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
-        _ir_consts(consts), *map(wrap_singleton_ir_values, args),
-        dim_var_values=ctx.dim_var_values)
-    ctx.set_tokens_out(tokens)
-    return out
+      if ctx.platforms is not None:
+        sub_context = ctx.module_context.replace(platforms=ctx.platforms)
+      else:
+        sub_context = ctx.module_context
+      out, tokens = jaxpr_subcomp(
+          sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
+          _ir_consts(consts), *map(wrap_singleton_ir_values, args),
+          dim_var_values=ctx.dim_var_values)
+      ctx.set_tokens_out(tokens)
+      return out
 
   return f_lowered
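For reference, `lower_fun` is how many primitives register their lowering rules, so after this change any rule built with it enters the equation's context exactly once, around both tracing and `jaxpr_subcomp`. A hedged usage sketch follows; `mlir.register_lowering` and `mlir.lower_fun` are real JAX-internal APIs, but the primitive and its implementation here are made up for illustration:

```python
from jax._src import core
from jax._src.interpreters import mlir

# Hypothetical primitive; 'my_double' and its impl are illustrative only.
my_double_p = core.Primitive("my_double")
my_double_p.def_abstract_eval(lambda x: x)  # shape/dtype-preserving

def _my_double_impl(x):
    return 2 * x

# lower_fun traces _my_double_impl to a jaxpr and lowers it via
# jaxpr_subcomp; with this commit, eqn.ctx.manager is entered here
# (once per lower_fun call) rather than around every equation by the caller.
mlir.register_lowering(
    my_double_p, mlir.lower_fun(_my_double_impl, multiple_results=False))
```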
@@ -1886,9 +1887,9 @@ def map_compute_type(c_type):
                      'are `device_host` and `device`')
 
 
 def wrap_compute_type_in_place(ctx, op):
-  if ctx.compute_type is not None:
+  if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.compute_type is not None:
     dict_attr = {"_xla_compute_type": ir.StringAttr.get(
-        map_compute_type(ctx.compute_type))}
+        map_compute_type(ctx.jaxpr_eqn_ctx.compute_type))}
     op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)
 
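The last hunk follows from the field move: `compute_type` is no longer a field of `LoweringRuleContext`, so `wrap_compute_type_in_place` reads it through the optional `jaxpr_eqn_ctx`. A minimal sketch of that None-safe lookup (the `EqnCtx` class is a hypothetical stand-in, not JAX's):

```python
from dataclasses import dataclass

@dataclass
class EqnCtx:
    compute_type: str | None = None  # hypothetical stand-in for JaxprEqnContext

def effective_compute_type(jaxpr_eqn_ctx: EqnCtx | None) -> str | None:
    # Mirrors the new guard: only report a compute type when both the
    # equation context and its compute_type are present.
    if jaxpr_eqn_ctx is not None and jaxpr_eqn_ctx.compute_type is not None:
        return jaxpr_eqn_ctx.compute_type
    return None

assert effective_compute_type(None) is None
assert effective_compute_type(EqnCtx("device_host")) == "device_host"
```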