mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18369 from gnecula:lower_clean
PiperOrigin-RevId: 583994157
This commit is contained in:
commit
b8fe80931e
@ -1416,46 +1416,23 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
ctx.name_stack)
|
||||
with source_info_util.user_context(eqn.source_info.traceback), loc:
|
||||
override_rule = get_override_lowering_rule(eqn.primitive)
|
||||
if len(ctx.platforms) == 1:
|
||||
# Classic, single-platform lowering
|
||||
# TODO(necula): unify the code paths when multi-platform is finished
|
||||
platform = ctx.platforms[0]
|
||||
if override_rule is not None:
|
||||
rule = override_rule
|
||||
elif eqn.primitive in _platform_specific_lowerings[platform]:
|
||||
rule = _platform_specific_lowerings[platform][eqn.primitive]
|
||||
elif eqn.primitive in xla._backend_specific_translations[platform]:
|
||||
rule = xla_fallback_lowering(eqn.primitive)
|
||||
elif eqn.primitive in _lowerings:
|
||||
rule = _lowerings[eqn.primitive]
|
||||
elif eqn.primitive in xla._translations:
|
||||
rule = xla_fallback_lowering(eqn.primitive)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
|
||||
f"found for platform {platform}")
|
||||
platform_rules: dict[str, LoweringRule] = {}
|
||||
default_rule: Optional[LoweringRule] = None
|
||||
# See mlir.lower_per_platform for meaning of `platform_rules` and `default_rule`
|
||||
if override_rule is not None:
|
||||
default_rule = override_rule
|
||||
else:
|
||||
rules: list[MultiPlatformLoweringRule]
|
||||
# See mlir.lower_multi_platform for the `rules` format
|
||||
if override_rule is not None:
|
||||
rules = [(None, override_rule)]
|
||||
else:
|
||||
# First the platform-specific rules
|
||||
rules = []
|
||||
for p in ctx.platforms:
|
||||
if eqn.primitive in _platform_specific_lowerings[p]:
|
||||
rules.append(
|
||||
([p], _platform_specific_lowerings[p][eqn.primitive]))
|
||||
elif eqn.primitive in xla._backend_specific_translations[p]:
|
||||
rules.append(
|
||||
([p], xla_fallback_lowering(eqn.primitive)))
|
||||
# Now the catch-all rules
|
||||
if eqn.primitive in _lowerings:
|
||||
rules.append(
|
||||
(None, _lowerings[eqn.primitive])) # type: ignore
|
||||
elif eqn.primitive in xla._translations:
|
||||
rules.append(
|
||||
(None, xla_fallback_lowering(eqn.primitive))) # type: ignore
|
||||
# First the platform-specific rules
|
||||
for p in ctx.platforms:
|
||||
if eqn.primitive in _platform_specific_lowerings[p]:
|
||||
platform_rules[p] = _platform_specific_lowerings[p][eqn.primitive]
|
||||
elif eqn.primitive in xla._backend_specific_translations[p]:
|
||||
platform_rules[p] = xla_fallback_lowering(eqn.primitive)
|
||||
# Now the default rule
|
||||
if eqn.primitive in _lowerings:
|
||||
default_rule = _lowerings[eqn.primitive]
|
||||
elif eqn.primitive in xla._translations:
|
||||
default_rule = xla_fallback_lowering(eqn.primitive)
|
||||
|
||||
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
|
||||
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
|
||||
@ -1473,13 +1450,10 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
|
||||
|
||||
rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
|
||||
if len(ctx.platforms) == 1:
|
||||
# Classic, single-platform lowering
|
||||
ans = rule(rule_ctx, *rule_inputs, **eqn.params)
|
||||
else:
|
||||
ans = lower_multi_platform(rule_ctx, str(eqn), rules,
|
||||
eqn.effects,
|
||||
*rule_inputs, **eqn.params)
|
||||
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
|
||||
platform_rules, default_rule,
|
||||
eqn.effects,
|
||||
*rule_inputs, **eqn.params)
|
||||
|
||||
if effects:
|
||||
# If there were ordered effects in the primitive, there should be output
|
||||
@ -1510,82 +1484,82 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
core.clean_up_dead_vars(eqn, env, last_used)
|
||||
return map(read, jaxpr.outvars), tokens
|
||||
|
||||
# See docstring for lower_multi_platform.
|
||||
MultiPlatformLoweringRule = tuple[Optional[Sequence[str]], Callable]
|
||||
|
||||
def lower_multi_platform(ctx: LoweringRuleContext,
|
||||
description: str,
|
||||
rules: Sequence[MultiPlatformLoweringRule],
|
||||
effects: effects_lib.Effects,
|
||||
*rule_args: ir.Value,
|
||||
**rule_kwargs) -> ir.Value:
|
||||
"""Emits single- or multi-platform code for a primitive.
|
||||
|
||||
def lower_per_platform(ctx: LoweringRuleContext,
|
||||
description: str,
|
||||
platform_rules: dict[str, LoweringRule],
|
||||
default_rule: Optional[LoweringRule],
|
||||
effects: effects_lib.Effects,
|
||||
*rule_args: ir.Value,
|
||||
**rule_kwargs) -> ir.Value:
|
||||
"""Emits code for a primitive for the current lowering platform(s).
|
||||
|
||||
For example, given
|
||||
ctx.module_context.lowering_parameters.platforms = ("cpu", "gpu", "tpu")
|
||||
platform_rules = dict(tpu=rule0, cpu=rule0)
|
||||
default_rule = rule1
|
||||
|
||||
and
|
||||
rules = [(["tpu", "cpu"], rule0),
|
||||
(None, rule1)
|
||||
ctx.module_context.lowering_parameters.platforms = ("cpu",)
|
||||
|
||||
emits:
|
||||
rule0(ctx, *rule_args, **rule_kwargs)
|
||||
|
||||
In case of multi-platform lowering, e.g., if
|
||||
ctx.module_context.lowering_parameters.platforms = ("cpu", "cuda", "tpu")
|
||||
|
||||
emits:
|
||||
rule_idx = case current_platform_idx:
|
||||
0: return 0 # cpu rule index
|
||||
1: return 1 # gpu rule index
|
||||
1: return 1 # cuda rule index
|
||||
2: return 0 # tpu rule index
|
||||
output = case rule_idx
|
||||
0: return rule0(*rule_args, **rule_kwargs)
|
||||
1: return rule1(*rule_args, **rule_kwargs)
|
||||
|
||||
If the primitive has a single lowering rule for all platforms of interest,
|
||||
skips the conditionals and emits the same code as for classic single-platform
|
||||
lowering.
|
||||
|
||||
Args:
|
||||
ctx: lowering context.
|
||||
description: a string to include in error messages.
|
||||
rules: a sequence of per-platform rules. Each entry is a tuple, with the
|
||||
first element specifying the platforms, either a sequence of applicable
|
||||
platform names (maybe empty), or None to denote a default entry to use
|
||||
when no other entry applies. The second element of the tuple is a
|
||||
lowering rule, i.e., a function to invoke with a
|
||||
LoweringRuleContext (a sub-context of `ctx`),
|
||||
and `*rule_args` and `**rule_kwargs`.
|
||||
platform_rules: map platform names, e.g., "cpu", "cuda", to
|
||||
`LoweringRule`s, for the platforms that have non-default lowering.
|
||||
default_rule: an optional rule to use for platforms not in `platform_rules`.
|
||||
effects: the set of effects for the current primitive.
|
||||
rule_args: the args of the lowering rules.
|
||||
rule_kwargs: the kwargs of the lowering rules.
|
||||
"""
|
||||
platforms: Sequence[str] = ctx.module_context.platforms
|
||||
platforms_with_specific_rules: Sequence[str] = util.flatten(
|
||||
[ps for ps, _ in rules if ps is not None])
|
||||
platforms_with_default_rule = [p for p in platforms
|
||||
if p not in platforms_with_specific_rules]
|
||||
kept_rules: list[MultiPlatformLoweringRule] = [] # Only the rules for platforms of interest
|
||||
# Special case the common case (single-platform lowering)
|
||||
if len(platforms) == 1:
|
||||
rule = platform_rules.get(platforms[0], default_rule)
|
||||
if rule is None:
|
||||
raise NotImplementedError(
|
||||
f"MLIR translation rule for primitive '{description}' not "
|
||||
f"found for platform {platforms[0]}")
|
||||
|
||||
# Multi-platform lowering
|
||||
kept_rules: list[LoweringRule] = [] # Only the rules for the platforms of interest
|
||||
platform_to_kept_rules_idx: dict[str, int] = {}
|
||||
for ps, r in rules:
|
||||
rule_index = len(kept_rules)
|
||||
if ps is not None:
|
||||
# Keep only rules that mention the platforms of interest
|
||||
interesting_ps = [p for p in platforms if p in ps] # type: ignore
|
||||
if interesting_ps:
|
||||
for p in interesting_ps:
|
||||
assert p not in platform_to_kept_rules_idx
|
||||
platform_to_kept_rules_idx[p] = rule_index
|
||||
kept_rules.append((interesting_ps, r))
|
||||
elif platforms_with_default_rule:
|
||||
for p in platforms_with_default_rule:
|
||||
assert p not in platform_to_kept_rules_idx
|
||||
platform_to_kept_rules_idx[p] = rule_index
|
||||
kept_rules.append((platforms_with_default_rule, r))
|
||||
for p, prule in platform_rules.items():
|
||||
if p not in platforms:
|
||||
continue
|
||||
platform_to_kept_rules_idx[p] = len(kept_rules)
|
||||
kept_rules.append(prule)
|
||||
|
||||
platforms_without_specific_rule = [p for p in platforms
|
||||
if p not in platform_to_kept_rules_idx]
|
||||
if platforms_without_specific_rule:
|
||||
if default_rule is None:
|
||||
raise NotImplementedError(
|
||||
f"MLIR translation rule for primitive '{description}' not "
|
||||
f"found for platforms {platforms_without_specific_rule}")
|
||||
for p in platforms_without_specific_rule:
|
||||
platform_to_kept_rules_idx[p] = len(kept_rules)
|
||||
kept_rules.append(default_rule)
|
||||
|
||||
platforms_without_rules = [p for p in platforms
|
||||
if p not in platform_to_kept_rules_idx]
|
||||
if platforms_without_rules:
|
||||
raise ValueError(
|
||||
f"MLIR translation rule for primitive '{description}' not "
|
||||
f"found for platforms {platforms_without_rules}")
|
||||
assert kept_rules
|
||||
|
||||
# Maybe there is a single rule left, just apply the rule, no conditionals.
|
||||
# If there is a single rule left just apply the rule, without conditionals.
|
||||
if len(kept_rules) == 1:
|
||||
return kept_rules[0][1](ctx, *rule_args, **rule_kwargs)
|
||||
return kept_rules[0](ctx, *rule_args, **rule_kwargs)
|
||||
|
||||
assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
|
||||
assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"
|
||||
@ -1609,7 +1583,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
|
||||
case_op = hlo.CaseOp(util.flatten(output_types),
|
||||
index=rule_idx_op,
|
||||
num_branches=len(kept_rules))
|
||||
for i, (_, rule) in enumerate(kept_rules):
|
||||
for i, rule in enumerate(kept_rules):
|
||||
inner_ctx = ctx.replace()
|
||||
branch = case_op.regions[i].blocks.append()
|
||||
with ir.InsertionPoint(branch):
|
||||
|
@ -938,7 +938,7 @@ def platform_dependent(*args: Any,
|
||||
platform_branches: list[tuple[list[str], Callable]] = []
|
||||
for pname, pbranch in per_platform.items():
|
||||
if pname == "gpu":
|
||||
raise ValueError("Use 'cuda' or 'rocm' for this API.")
|
||||
raise ValueError("Use 'cuda' or 'rocm' for lax.platform_dependent.")
|
||||
for ps, b in platform_branches:
|
||||
if b == pbranch:
|
||||
ps.append(pname)
|
||||
@ -979,18 +979,17 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
|
||||
has_default: bool):
|
||||
def lower_constant(ctx: mlir.LoweringRuleContext, *, i: int) -> mlir.ir.Value:
|
||||
return mlir.ir_constants(np.int32(i))
|
||||
lowering_rules: tuple[mlir.MultiPlatformLoweringRule, ...] = tuple(
|
||||
(ps, partial(lower_constant, i=i))
|
||||
for i, ps in enumerate(platforms)
|
||||
)
|
||||
if has_default:
|
||||
lowering_rules = lowering_rules + (
|
||||
(None, partial(lower_constant, i=len(platforms))),
|
||||
)
|
||||
return mlir.lower_multi_platform(
|
||||
platform_rules: dict[str, mlir.LoweringRule] = {}
|
||||
for i, ps in enumerate(platforms):
|
||||
rule = partial(lower_constant, i=i)
|
||||
for p in ps:
|
||||
platform_rules[p] = rule
|
||||
|
||||
default_rule = (
|
||||
partial(lower_constant, i=len(platforms)) if has_default else None)
|
||||
return mlir.lower_per_platform(
|
||||
ctx,
|
||||
f"platform_index(platforms={platforms}, has_default={has_default})",
|
||||
lowering_rules,
|
||||
effects.no_effects)
|
||||
platform_rules, default_rule, effects.no_effects)
|
||||
|
||||
mlir.register_lowering(platform_index_p, _platform_index_lowering)
|
||||
|
@ -2782,7 +2782,7 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
ctx = contextlib.ExitStack()
|
||||
if jtu.device_under_test() != "tpu":
|
||||
ctx.enter_context(
|
||||
self.assertRaisesRegex(ValueError,
|
||||
self.assertRaisesRegex(NotImplementedError,
|
||||
"translation rule .* not found for platform"))
|
||||
with ctx:
|
||||
lax.platform_dependent(
|
||||
|
Loading…
x
Reference in New Issue
Block a user