Merge pull request #17423 from gnecula:export_multi2
PiperOrigin-RevId: 569993648
commit 4471abe1c6
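In brief, this change lets the experimental export machinery lower one function for several platforms at once. A minimal usage sketch, modeled on the tests touched in this diff (the `jax.experimental.export` import path is an assumption based on the test file; `jnp.sin` is just a stand-in function):

import numpy as np
import jax.numpy as jnp
from jax.experimental.export import export  # assumed import path

x = np.arange(5, dtype=np.float32)
# Lower once for both CPU and TPU; the exported module carries a
# platform-index argument and selects the right code path when called.
exp = export.export(jnp.sin, lowering_platforms=("cpu", "tpu"))(x)
assert exp.lowering_platforms == ("cpu", "tpu")
res = export.call_exported(exp)(x)  # runs on whichever of the two platforms is present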
@@ -418,7 +418,6 @@ class ShapePolyLoweringState:
    self.has_platform_index_argument = False
    self.dim_vars = dim_vars


@dataclasses.dataclass(frozen=True)
class LoweringParameters:
  # A mapping between primitives and user-defined LoweringRules.
@@ -429,7 +428,8 @@ class LoweringParameters:
  # The current lowering platforms, a non-empty tuple containing some of
  # 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
  # doing multi-platform lowering, otherwise it can specify cross-platform
  # lowering. The value None specify default lowering platform.
  # lowering. The value None specifies default lowering platform, for the
  # platform specified by `ModuleContext.platform`.
  # This is used only in export and jax2tf.
  platforms: tuple[str, ...] | None = None

@@ -446,6 +446,10 @@ class LoweringParameters:
    else:
      return None

  @property
  def is_multi_platform(self) -> bool:
    return self.platforms is not None and len(self.platforms) > 1


@dataclasses.dataclass
class ModuleContext:
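For illustration (not part of the diff), the three configurations the `platforms` field describes, per the comments above; this assumes the dataclass's other fields all have defaults, as `platforms` does:

# Default lowering for the single platform recorded in ModuleContext.platform.
assert not LoweringParameters().is_multi_platform
# Cross-platform lowering for one explicitly named platform.
assert not LoweringParameters(platforms=("tpu",)).is_multi_platform
# Multi-platform lowering: lower once for several platforms.
assert LoweringParameters(platforms=("cpu", "tpu")).is_multi_platform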
@@ -689,7 +693,7 @@ def lower_jaxpr_to_module(
    *,
    ordered_effects: list[core.Effect],
    backend_or_name: str | xb.XlaBackend | None,
    platform: str | tuple[str, ...],
    platform: str,
    axis_context: AxisContext,
    name_stack: source_info_util.NameStack,
    donated_args: Sequence[bool],
@@ -708,14 +712,14 @@ def lower_jaxpr_to_module(
  Handles the quirks of the argument/return value passing conventions of the
  runtime.
  """
  if lowering_parameters.platforms is not None:
    # Only for multi-platform lowering
    # TODO(necula): for now we lower only for the first platform
    platform = lowering_parameters.platforms[0]
  platforms: tuple[str, ...]
  platform = xb.canonicalize_platform(platform)
  if lowering_parameters.is_multi_platform:
    platforms = tuple(map(xb.canonicalize_platform,
                          lowering_parameters.platforms))  # type: ignore
  else:
    platforms = (platform,)

  platform = xb.canonicalize_platform(platform)  # type: ignore
  if not xb.is_known_platform(platform):
    raise ValueError(f"Unknown platform {platform}")
  input_output_aliases = None
  in_avals = (jaxpr.in_avals if arg_shardings is None else
              map(sharded_aval, jaxpr.in_avals, arg_shardings))
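A worked instance (illustrative, not part of the diff) of the new branch above, assuming the canonical platform names pass through unchanged:

params = LoweringParameters(platforms=("cpu", "tpu"))
platform = xb.canonicalize_platform("cpu")                            # "cpu"
# params.is_multi_platform is True, so the multi-platform arm is taken:
platforms = tuple(map(xb.canonicalize_platform, params.platforms))    # ("cpu", "tpu")
# With params.platforms = None, the else arm gives platforms == (platform,).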
@@ -730,7 +734,14 @@ def lower_jaxpr_to_module(
  result_memory_kinds = (map(_get_mem_kind, result_shardings)
                         if result_shardings is not None else None)

  if platform in _platforms_with_donation:
  platforms_with_donation = [p for p in platforms
                             if p in _platforms_with_donation]
  if platforms_with_donation:
    if len(platforms_with_donation) != len(platforms):
      raise NotImplementedError(
          "In multi-platform lowering either all or no lowering platforms "
          f"should support donation. Lowering for {platforms} of which "
          f"only {platforms_with_donation} support donation")
    input_output_aliases, donated_args = _set_up_aliases(
        in_avals, out_avals, donated_args, arg_memory_kinds, result_memory_kinds)
  unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
@@ -739,8 +750,8 @@ def lower_jaxpr_to_module(
  if any(donated_args):
    unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
    msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
    if platform not in _platforms_with_donation:
      msg = f"Donation is not implemented for {platform}.\n{msg}"
    if not platforms_with_donation:
      msg = f"Donation is not implemented for {platforms}.\n{msg}"
    if unused_donations:
      warnings.warn("Some donated buffers were not usable:"
                    f" {', '.join(unused_donations)}.\n{msg}")
@@ -805,9 +816,17 @@ def lower_jaxpr_to_module(
      raise ValueError(
          f"Cannot lower jaxpr with verifier errors: {module_string}")
  except ir.MLIRError as e:
    module_string = module_to_string(ctx.module)
    raise ValueError(
        f"Cannot lower jaxpr with verifier errors: {module_string}") from e
    msg_lines = ["Cannot lower jaxpr with verifier errors:"]
    def emit_diagnostic_info(d):
      msg_lines.append(f"\t{d.message}")
      msg_lines.append(f"\t\tat {d.location}")
      for n in d.notes:
        emit_diagnostic_info(n)
    for d in e.error_diagnostics:
      emit_diagnostic_info(d)
    msg_lines.append("Module string:")
    msg_lines.append(module_to_string(ctx.module))
    raise ValueError("\n".join(msg_lines)) from e

  return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
                        ctx.shape_poly_state)
@@ -1314,7 +1333,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
      assert node is not None
      env[v] = tuple(node)

  def get_lowering(primitive: core.Primitive) -> LoweringRule | None:
  def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None:
    if ctx.lowering_parameters.override_lowering_rules is None:
      return None
    for p, rule in ctx.lowering_parameters.override_lowering_rules:
@@ -1339,21 +1358,46 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
      loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
                                     ctx.name_stack)
      with source_info_util.user_context(eqn.source_info.traceback), loc:
        override_rule = get_lowering(eqn.primitive)
        if override_rule is not None:
          rule = override_rule
        elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
          rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
        elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
          rule = xla_fallback_lowering(eqn.primitive)
        elif eqn.primitive in _lowerings:
          rule = _lowerings[eqn.primitive]
        elif eqn.primitive in xla._translations:
          rule = xla_fallback_lowering(eqn.primitive)
        override_rule = get_override_lowering_rule(eqn.primitive)
        if not ctx.lowering_parameters.is_multi_platform:
          # Classic, single-platform lowering
          # TODO(necula): unify the code paths when multi-platform is finished
          if override_rule is not None:
            rule = override_rule
          elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
            rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
          elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
            rule = xla_fallback_lowering(eqn.primitive)
          elif eqn.primitive in _lowerings:
            rule = _lowerings[eqn.primitive]
          elif eqn.primitive in xla._translations:
            rule = xla_fallback_lowering(eqn.primitive)
          else:
            raise NotImplementedError(
                f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
                f"found for platform {ctx.platform}")
        else:
          raise NotImplementedError(
              f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
              f"found for platform {ctx.platform}")
          rules: list[MultiPlatformLoweringRule]
          # See mlir.lower_multi_platform for the `rules` format
          if override_rule is not None:
            rules = [(None, override_rule)]
          else:
            # First the platform-specific rules
            rules = []
            for p in ctx.lowering_parameters.platforms:  # type: ignore
              if eqn.primitive in _platform_specific_lowerings[p]:
                rules.append(
                    ([p], _platform_specific_lowerings[p][eqn.primitive]))
              elif eqn.primitive in xla._backend_specific_translations[p]:
                rules.append(
                    ([p], xla_fallback_lowering(eqn.primitive)))
            # Now the catch-all rules
            if eqn.primitive in _lowerings:
              rules.append(
                  (None, _lowerings[eqn.primitive]))  # type: ignore
            elif eqn.primitive in xla._translations:
              rules.append(
                  (None, xla_fallback_lowering(eqn.primitive)))  # type: ignore

      eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
      effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
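An illustrative example (not part of the diff) of the `rules` list this branch assembles for a primitive that has a CPU-specific lowering plus a generic fallback, when lowering for ('cpu', 'tpu'); `cpu_rule` and `generic_rule` are hypothetical names:

rules = [
    (["cpu"], cpu_rule),     # platform-specific rule, applies to cpu only
    (None, generic_rule),    # catch-all rule used for the remaining platforms
]
# lower_multi_platform (added further below) consumes this list and emits the
# per-platform dispatch.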
@@ -1368,8 +1412,15 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                          for a in avals_in if type(a) is core.DShapedArray
                          for d in a.shape if type(d) is core.Var}
        rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
        ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
                   **eqn.params)

        rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
        if not ctx.lowering_parameters.is_multi_platform:
          # Classic, single-platform lowering
          ans = rule(rule_ctx, *rule_inputs, **eqn.params)
        else:
          ans = lower_multi_platform(rule_ctx, str(eqn), rules,
                                     *rule_inputs, **eqn.params)

        if effects:
          # If there were ordered effects in the primitive, there should be output
          # tokens we need for subsequent ordered effects.
@@ -1399,6 +1450,115 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
    core.clean_up_dead_vars(eqn, env, last_used)
  return map(read, jaxpr.outvars), tokens

# See docstring for lower_multi_platform.
MultiPlatformLoweringRule = tuple[Optional[Sequence[str]], Callable]

def lower_multi_platform(ctx: LoweringRuleContext,
                         description: str,
                         rules: Sequence[MultiPlatformLoweringRule],
                         *rule_args: ir.Value,
                         **rule_kwargs) -> ir.Value:
  """Emits single- or multi-platform code for a primitive.

  For example, given
    ctx.module_context.lowering_parameters.platforms = ("cpu", "gpu", "tpu")
  and
    rules = [(["tpu", "cpu"], rule0),
             (None, rule1)]
  emits:
    rule_idx = case current_platform_idx:
      0: return 0  # cpu rule index
      1: return 1  # gpu rule index
      2: return 0  # tpu rule index
    output = case rule_idx
      0: return rule0(*rule_args, **rule_kwargs)
      1: return rule1(*rule_args, **rule_kwargs)

  If the primitive has a single lowering rule for all platforms of interest,
  skips the conditionals and emits the same code as for classic single-platform
  lowering.

  Args:
    ctx: lowering context.
    description: a string to include in error messages.
    rules: a sequence of per-platform rules. Each entry is a tuple, with the
      first element specifying the platforms, either a sequence of applicable
      platform names (maybe empty), or None to denote a default entry to use
      when no other entry applies. The second element of the tuple is a
      lowering rule, i.e., a function to invoke with a
      LoweringRuleContext (a sub-context of `ctx`),
      and `*rule_args` and `**rule_kwargs`.
    rule_args: the args of the lowering rules.
    rule_kwargs: the kwargs of the lowering rules.
  """
  assert isinstance(ctx.module_context.lowering_parameters.platforms, tuple)
  platforms = ctx.module_context.lowering_parameters.platforms
  platforms_with_specific_rules = util.flatten(
      [ps for ps, _ in rules if ps is not None])
  platforms_with_default_rule = [p for p in platforms
                                 if p not in platforms_with_specific_rules]
  kept_rules: list[MultiPlatformLoweringRule] = []  # Only the rules for platforms of interest
  platform_to_kept_rules_idx: dict[str, int] = {}
  for ps, r in rules:
    rule_index = len(kept_rules)
    if ps is not None:
      # Keep only rules that mention the platforms of interest
      interesting_ps = [p for p in platforms if p in ps]
      if interesting_ps:
        for p in interesting_ps:
          assert p not in platform_to_kept_rules_idx
          platform_to_kept_rules_idx[p] = rule_index
        kept_rules.append((interesting_ps, r))
    elif platforms_with_default_rule:
      for p in platforms_with_default_rule:
        assert p not in platform_to_kept_rules_idx
        platform_to_kept_rules_idx[p] = rule_index
      kept_rules.append((platforms_with_default_rule, r))

  platforms_without_rules = [p for p in platforms
                             if p not in platform_to_kept_rules_idx]
  if platforms_without_rules:
    raise ValueError(
        f"MLIR translation rule for primitive '{description}' not "
        f"found for platforms {platforms_without_rules}")
  assert kept_rules

  # Maybe there is a single rule left, just apply the rule, no conditionals.
  if len(kept_rules) == 1:
    return kept_rules[0][1](ctx, *rule_args, **rule_kwargs)

  assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
  assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"

  # The first dim_var_values is the platform index
  current_platform_idx = ctx.dim_var_values[0]
  # Compute the rule index based on the current platform
  i32_type = aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
  if current_platform_idx.type != i32_type:
    current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
  rule_idx_op = hlo.CaseOp([i32_type],
                           index=current_platform_idx,
                           num_branches=len(platforms))
  for i, p in enumerate(platforms):
    branch = rule_idx_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      hlo.ReturnOp(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
  case_op = hlo.CaseOp(util.flatten(map(aval_to_ir_types, ctx.avals_out)),
                       index=rule_idx_op,
                       num_branches=len(kept_rules))
  for i, (_, rule) in enumerate(kept_rules):
    inner_ctx = ctx.replace()
    branch = case_op.regions[i].blocks.append()
    with ir.InsertionPoint(branch):
      output = rule(inner_ctx, *rule_args, **rule_kwargs)
      try:
        out_nodes = map(wrap_singleton_ir_values, output)
      except TypeError as e:
        raise ValueError("Output of translation rule must be iterable: "
                         f"{description}, got output {output}") from e
      hlo.ReturnOp(util.flatten(out_nodes))

  return case_op.results

def _ir_consts(consts):
  unique_consts = {id(const): const for const in consts}
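A worked trace (illustrative, not part of the diff) of the bookkeeping above, using the docstring's example inputs; rule0 and rule1 stand in for real lowering rules:

platforms = ("cpu", "gpu", "tpu")
rules = [(["tpu", "cpu"], rule0), (None, rule1)]
# After the loop over `rules`:
#   kept_rules == [(["cpu", "tpu"], rule0), (["gpu"], rule1)]
#   platform_to_kept_rules_idx == {"cpu": 0, "tpu": 0, "gpu": 1}
# so the first CaseOp maps platform index (0: cpu, 1: gpu, 2: tpu) to rule
# index (0, 1, 0), and the second CaseOp invokes rule0 or rule1 accordingly.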
@@ -2219,7 +2219,7 @@ def lower_mesh_computation(
      closed_jaxpr,
      ordered_effects=ordered_effects,
      backend_or_name=backend,
      platform=lowering_parameters.platforms or backend.platform,
      platform=lowering_parameters.override_platform or backend.platform,
      axis_context=axis_ctx,
      name_stack=name_stack,
      donated_args=donated_invars,
@@ -16,7 +16,7 @@ import functools
import logging
import math
import re
from typing import Optional
from typing import Optional, Sequence
import unittest

from absl.testing import absltest
@@ -38,6 +38,43 @@ import numpy as np

config.parse_flags_with_absl()

# A primitive for testing multi-platform lowering. Takes one argument and
# adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
_testing_multi_platform_p = core.Primitive("testing_multi_platform")
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)

@_testing_multi_platform_p.def_abstract_eval
def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
  assert xaval.dtype == np.float32  # type: ignore
  return xaval

def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
                                     x: mlir.Value,
                                     *,
                                     platform: str) -> Sequence[mlir.Value]:
  to_add = _testing_multi_platform_to_add[platform]
  to_add_value = mlir.broadcast_in_dim(ctx,
                                       mlir.ir_constant(np.float32(to_add)),
                                       ctx.avals_in[0],
                                       broadcast_dimensions=())
  return mlir.hlo.AddOp(x, to_add_value).results

# Register a default rule for cuda, to test the default-platform rule selection.
mlir.register_lowering(_testing_multi_platform_p,
                       functools.partial(_testing_multi_platform_lowering,
                                         platform="cuda"))
for platform in ["cpu", "tpu", "rocm"]:
  mlir.register_lowering(_testing_multi_platform_p,
                         functools.partial(_testing_multi_platform_lowering,
                                           platform=platform),
                         platform=platform)

def _testing_multi_platform_func(x):
  return _testing_multi_platform_p.bind(x)

def _testing_multi_platform_fun_expected(x):
  return x + _testing_multi_platform_to_add[xb.canonicalize_platform(jtu.device_under_test())]


class JaxExportTest(jtu.JaxTestCase):

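A quick sanity check (illustrative, not part of the diff) of what the helper computes on a single-platform backend, given the per-platform constants above:

import jax

x = np.float32(1.0)
# On a CPU backend the registered cpu rule adds 2.0, so the result is 3.0 and
# matches the expectation helper defined above.
assert jax.jit(_testing_multi_platform_func)(x) == _testing_multi_platform_fun_expected(x)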
@@ -568,54 +605,50 @@ class JaxExportTest(jtu.JaxTestCase):
      # The export is not applicable to GPU
      raise unittest.SkipTest("Not intended for running on GPU")
    x = np.arange(5, dtype=np.float32)
    # TODO: use a function with different behavior for different platforms
    exp = export.export(jnp.sin,
    exp = export.export(_testing_multi_platform_func,
                        lowering_platforms=('cpu', 'tpu'))(x)
    self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
    module_str = str(exp.mlir_module())
    platform_index = re.findall(
        r"(%arg\d):\s*tensor<i..>\s*{jax.platform_index = true}",
        module_str)
    self.assertEqual(["%arg0"], platform_index,
                     f"Found {platform_index} in {module_str}")
    expected_main_re = (
        r"@main\("
        r"%arg0: tensor<i..> {jax.platform_index = true}.*, "
        r"%arg1: tensor<5xf32>.* ->")
    self.assertRegex(module_str, expected_main_re)
    res = export.call_exported(exp)(x)
    self.assertAllClose(res, np.sin(x))
    self.assertAllClose(res, _testing_multi_platform_fun_expected(x))

  def test_multi_platform_nested(self):
    if jtu.test_device_matches(["tpu"]):
      # The outer export is not applicable to TPU
      raise unittest.SkipTest("Not intended for running on TPU")
    x = np.arange(5, dtype=np.float32)
    # TODO: use a function with different behavior for different platforms
    exp = export.export(jnp.sin,
                        lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
    exp = export.export(_testing_multi_platform_func,
                        lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
    self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))

    # Now serialize the call to the exported using a different sequence of
    # lowering platforms, but included in the lowering platforms for the
    # nested exported.
    # TODO: improve this test once we implement true multi-platform lowering
    exp2 = export.export(export.call_exported(exp),
                         lowering_platforms=('cpu', 'cuda'))(x)
                         lowering_platforms=('cpu', 'cuda'))(x)
    res2 = export.call_exported(exp2)(x)
    self.assertAllClose(res2, np.sin(x))
    self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))

  def test_multi_platform_and_poly(self):
    if jtu.test_device_matches(["gpu"]):
      # The export is not applicable to GPU
      raise unittest.SkipTest("Not intended for running on GPU")
    # TODO: use a function with different behavior for different platforms
    exp = export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)),
                        lowering_platforms=('cpu', 'tpu'))(
    exp = export.export(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)),
                        lowering_platforms=('cpu', 'tpu'))(
        export.poly_spec((5, 6), np.float32, "b1, b2")
    )
    x = np.arange(12, dtype=np.float32).reshape((3, 4))
    res = export.call_exported(exp)(x)
    self.assertAllClose(res, np.sin(x).reshape((-1,)))
    self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
    # Now serialize the call to the exported
    exp2 = export.export(export.call_exported(exp))(x)
    res2 = export.call_exported(exp2)(x)
    self.assertAllClose(res2, np.sin(x).reshape((-1,)))
    self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))

if __name__ == "__main__":