Merge pull request #17423 from gnecula:export_multi2

PiperOrigin-RevId: 569993648
jax authors 2023-10-02 02:28:52 -07:00
commit 4471abe1c6
3 changed files with 247 additions and 54 deletions


@@ -418,7 +418,6 @@ class ShapePolyLoweringState:
self.has_platform_index_argument = False
self.dim_vars = dim_vars
@dataclasses.dataclass(frozen=True)
class LoweringParameters:
# A mapping between primitives and user-defined LoweringRules.
@@ -429,7 +428,8 @@ class LoweringParameters:
# The current lowering platforms, a non-empty tuple containing some of
# 'cpu', 'cuda', 'rocm', 'tpu'. If the tuple has multiple entries we are
# doing multi-platform lowering, otherwise it can specify cross-platform
# lowering. The value None specify default lowering platform.
# lowering. The value None specifies default lowering platform, for the
# platform specified by `ModuleContext.platform`.
# This is used only in export and jax2tf.
platforms: tuple[str, ...] | None = None
@@ -446,6 +446,10 @@ class LoweringParameters:
else:
return None
@property
def is_multi_platform(self) -> bool:
return self.platforms is not None and len(self.platforms) > 1
@dataclasses.dataclass
class ModuleContext:
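
The new `is_multi_platform` property is the switch the rest of this change keys off: `platforms=None` keeps the classic default path, a one-element tuple requests cross-platform lowering for that single platform, and two or more entries turn on the new multi-platform path. A minimal sketch of those semantics (the import path is JAX-internal and may move):

    from jax._src.interpreters import mlir

    assert not mlir.LoweringParameters().is_multi_platform                      # default lowering
    assert not mlir.LoweringParameters(platforms=("tpu",)).is_multi_platform    # cross-platform
    assert mlir.LoweringParameters(platforms=("cpu", "tpu")).is_multi_platform  # multi-platform
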
@@ -689,7 +693,7 @@ def lower_jaxpr_to_module(
*,
ordered_effects: list[core.Effect],
backend_or_name: str | xb.XlaBackend | None,
platform: str | tuple[str, ...],
platform: str,
axis_context: AxisContext,
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
@@ -708,14 +712,14 @@ def lower_jaxpr_to_module(
Handles the quirks of the argument/return value passing conventions of the
runtime.
"""
if lowering_parameters.platforms is not None:
# Only for multi-platform lowering
# TODO(necula): for now we lower only for the first platform
platform = lowering_parameters.platforms[0]
platforms: tuple[str, ...]
platform = xb.canonicalize_platform(platform)
if lowering_parameters.is_multi_platform:
platforms = tuple(map(xb.canonicalize_platform,
lowering_parameters.platforms)) # type: ignore
else:
platforms = (platform,)
platform = xb.canonicalize_platform(platform) # type: ignore
if not xb.is_known_platform(platform):
raise ValueError(f"Unknown platform {platform}")
input_output_aliases = None
in_avals = (jaxpr.in_avals if arg_shardings is None else
map(sharded_aval, jaxpr.in_avals, arg_shardings))
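
The canonicalization step now runs over every requested platform instead of a single one. A short sketch of what it produces, assuming `xb` is `jax._src.xla_bridge` (as in this module) and that a CUDA backend is installed so the "gpu" alias resolves to "cuda":

    from jax._src import xla_bridge as xb
    from jax._src.interpreters import mlir

    params = mlir.LoweringParameters(platforms=("cpu", "gpu"))
    platforms = tuple(map(xb.canonicalize_platform, params.platforms))
    # -> ("cpu", "cuda") here; on a ROCm machine the alias resolves to "rocm" instead
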
@@ -730,7 +734,14 @@ def lower_jaxpr_to_module(
result_memory_kinds = (map(_get_mem_kind, result_shardings)
if result_shardings is not None else None)
if platform in _platforms_with_donation:
platforms_with_donation = [p for p in platforms
if p in _platforms_with_donation]
if platforms_with_donation:
if len(platforms_with_donation) != len(platforms):
raise NotImplementedError(
"In multi-platform lowering either all or no lowering platforms "
f"should support donation. Lowering for {platforms} of which "
f"only {platforms_with_donation} support donation")
input_output_aliases, donated_args = _set_up_aliases(
in_avals, out_avals, donated_args, arg_memory_kinds, result_memory_kinds)
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
@@ -739,8 +750,8 @@ def lower_jaxpr_to_module(
if any(donated_args):
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
if platform not in _platforms_with_donation:
msg = f"Donation is not implemented for {platform}.\n{msg}"
if not platforms_with_donation:
msg = f"Donation is not implemented for {platforms}.\n{msg}"
if unused_donations:
warnings.warn("Some donated buffers were not usable:"
f" {', '.join(unused_donations)}.\n{msg}")
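
Donation is now checked against the whole platform tuple: if every requested platform supports it, input/output aliases are set up as before; if none does, the warning above fires; a mix raises NotImplementedError. A small worked illustration of the all-or-nothing check, using a hypothetical donation set in place of this module's `_platforms_with_donation`:

    platforms = ("cpu", "tpu")
    supports_donation = ("cpu",)   # hypothetical: pretend only cpu supported donation
    with_donation = [p for p in platforms if p in supports_donation]
    if with_donation and len(with_donation) != len(platforms):
        # mirrors the NotImplementedError raised in the hunk above
        raise NotImplementedError(
            f"all or no lowering platforms should support donation, got {with_donation}")
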
@@ -805,9 +816,17 @@ def lower_jaxpr_to_module(
raise ValueError(
f"Cannot lower jaxpr with verifier errors: {module_string}")
except ir.MLIRError as e:
module_string = module_to_string(ctx.module)
raise ValueError(
f"Cannot lower jaxpr with verifier errors: {module_string}") from e
msg_lines = ["Cannot lower jaxpr with verifier errors:"]
def emit_diagnostic_info(d):
msg_lines.append(f"\t{d.message}")
msg_lines.append(f"\t\tat {d.location}")
for n in d.notes:
emit_diagnostic_info(n)
for d in e.error_diagnostics:
emit_diagnostic_info(d)
msg_lines.append("Module string:")
msg_lines.append(module_to_string(ctx.module))
raise ValueError("\n".join(msg_lines)) from e
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
ctx.shape_poly_state)
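
The verifier-failure path now unpacks the MLIR diagnostics recursively instead of dumping only the module text. A self-contained model of the same flattening, using a stand-in `Diag` class in place of the real objects exposed by `ir.MLIRError.error_diagnostics`:

    from dataclasses import dataclass, field

    @dataclass
    class Diag:                      # stand-in for an MLIR diagnostic
        message: str
        location: str
        notes: list = field(default_factory=list)

    def format_verifier_errors(diags, module_string):
        msg_lines = ["Cannot lower jaxpr with verifier errors:"]
        def emit(d):                 # append the diagnostic, then recurse into its notes
            msg_lines.append(f"\t{d.message}")
            msg_lines.append(f"\t\tat {d.location}")
            for n in d.notes:
                emit(n)
        for d in diags:
            emit(d)
        msg_lines += ["Module string:", module_string]
        return "\n".join(msg_lines)

    err = Diag("op verifier failed", "f.py:3",
               notes=[Diag("see current operation", "f.py:3")])
    print(format_verifier_errors([err], "module @jit_f { ... }"))
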
@@ -1314,7 +1333,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
assert node is not None
env[v] = tuple(node)
def get_lowering(primitive: core.Primitive) -> LoweringRule | None:
def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None:
if ctx.lowering_parameters.override_lowering_rules is None:
return None
for p, rule in ctx.lowering_parameters.override_lowering_rules:
@@ -1339,21 +1358,46 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
loc = _source_info_to_location(eqn.primitive, eqn.params, source_info,
ctx.name_stack)
with source_info_util.user_context(eqn.source_info.traceback), loc:
override_rule = get_lowering(eqn.primitive)
if override_rule is not None:
rule = override_rule
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
rule = xla_fallback_lowering(eqn.primitive)
elif eqn.primitive in _lowerings:
rule = _lowerings[eqn.primitive]
elif eqn.primitive in xla._translations:
rule = xla_fallback_lowering(eqn.primitive)
override_rule = get_override_lowering_rule(eqn.primitive)
if not ctx.lowering_parameters.is_multi_platform:
# Classic, single-platform lowering
# TODO(necula): unify the code paths when multi-platform is finished
if override_rule is not None:
rule = override_rule
elif eqn.primitive in _platform_specific_lowerings[ctx.platform]:
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
rule = xla_fallback_lowering(eqn.primitive)
elif eqn.primitive in _lowerings:
rule = _lowerings[eqn.primitive]
elif eqn.primitive in xla._translations:
rule = xla_fallback_lowering(eqn.primitive)
else:
raise NotImplementedError(
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")
else:
raise NotImplementedError(
f"MLIR translation rule for primitive '{eqn.primitive.name}' not "
f"found for platform {ctx.platform}")
rules: list[MultiPlatformLoweringRule]
# See mlir.lower_multi_platform for the `rules` format
if override_rule is not None:
rules = [(None, override_rule)]
else:
# First the platform-specific rules
rules = []
for p in ctx.lowering_parameters.platforms: # type: ignore
if eqn.primitive in _platform_specific_lowerings[p]:
rules.append(
([p], _platform_specific_lowerings[p][eqn.primitive]))
elif eqn.primitive in xla._backend_specific_translations[p]:
rules.append(
([p], xla_fallback_lowering(eqn.primitive)))
# Now the catch-all rules
if eqn.primitive in _lowerings:
rules.append(
(None, _lowerings[eqn.primitive])) # type: ignore
elif eqn.primitive in xla._translations:
rules.append(
(None, xla_fallback_lowering(eqn.primitive))) # type: ignore
eqn_ctx = ctx.replace(name_stack=source_info.name_stack)
effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
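
For the multi-platform branch, `rules` pairs each candidate lowering with the platforms it serves: platform-specific registrations come first and a `None` entry is the catch-all used for any remaining platforms. For the testing primitive defined in the test file further below, lowered with `platforms = ("cpu", "tpu", "cuda")`, the list would come out roughly as follows (rule names are illustrative):

    rules = [
        (["cpu"], cpu_rule),      # from _platform_specific_lowerings["cpu"]
        (["tpu"], tpu_rule),      # from _platform_specific_lowerings["tpu"]
        (None, default_rule),     # from _lowerings; ends up covering "cuda"
    ]
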
@@ -1368,8 +1412,15 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
for a in avals_in if type(a) is core.DShapedArray
for d in a.shape if type(d) is core.Var}
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
**eqn.params)
rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
if not ctx.lowering_parameters.is_multi_platform:
# Classic, single-platform lowering
ans = rule(rule_ctx, *rule_inputs, **eqn.params)
else:
ans = lower_multi_platform(rule_ctx, str(eqn), rules,
*rule_inputs, **eqn.params)
if effects:
# If there were ordered effects in the primitive, there should be output
# tokens we need for subsequent ordered effects.
@@ -1399,6 +1450,115 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars), tokens
# See docstring for lower_multi_platform.
MultiPlatformLoweringRule = tuple[Optional[Sequence[str]], Callable]
def lower_multi_platform(ctx: LoweringRuleContext,
description: str,
rules: Sequence[MultiPlatformLoweringRule],
*rule_args: ir.Value,
**rule_kwargs) -> ir.Value:
"""Emits single- or multi-platform code for a primitive.
For example, given
ctx.module_context.lowering_parameters.platforms = ("cpu", "gpu", "tpu")
and
rules = [(["tpu", "cpu"], rule0),
(None, rule1)]
emits:
rule_idx = case current_platform_idx:
0: return 0 # cpu rule index
1: return 1 # gpu rule index
2: return 0 # tpu rule index
output = case rule_idx
0: return rule0(*rule_args, **rule_kwargs)
1: return rule1(*rule_args, **rule_kwargs)
If the primitive has a single lowering rule for all platforms of interest,
skips the conditionals and emits the same code as for classic single-platform
lowering.
Args:
ctx: lowering context.
description: a string to include in error messages.
rules: a sequence of per-platform rules. Each entry is a tuple, with the
first element specifying the platforms, either a sequence of applicable
platform names (maybe empty), or None to denote a default entry to use
when no other entry applies. The second element of the tuple is a
lowering rule, i.e., a function to invoke with a
LoweringRuleContext (a sub-context of `ctx`),
and `*rule_args` and `**rule_kwargs`.
rule_args: the args of the lowering rules.
rule_kwargs: the kwargs of the lowering rules.
"""
assert isinstance(ctx.module_context.lowering_parameters.platforms, tuple)
platforms = ctx.module_context.lowering_parameters.platforms
platforms_with_specific_rules = util.flatten(
[ps for ps, _ in rules if ps is not None])
platforms_with_default_rule = [p for p in platforms
if p not in platforms_with_specific_rules]
kept_rules: list[MultiPlatformLoweringRule] = [] # Only the rules for platforms of interest
platform_to_kept_rules_idx: dict[str, int] = {}
for ps, r in rules:
rule_index = len(kept_rules)
if ps is not None:
# Keep only rules that mention the platforms of interest
interesting_ps = [p for p in platforms if p in ps]
if interesting_ps:
for p in interesting_ps:
assert p not in platform_to_kept_rules_idx
platform_to_kept_rules_idx[p] = rule_index
kept_rules.append((interesting_ps, r))
elif platforms_with_default_rule:
for p in platforms_with_default_rule:
assert p not in platform_to_kept_rules_idx
platform_to_kept_rules_idx[p] = rule_index
kept_rules.append((platforms_with_default_rule, r))
platforms_without_rules = [p for p in platforms
if p not in platform_to_kept_rules_idx]
if platforms_without_rules:
raise ValueError(
f"MLIR translation rule for primitive '{description}' not "
f"found for platforms {platforms_without_rules}")
assert kept_rules
# Maybe there is a single rule left, just apply the rule, no conditionals.
if len(kept_rules) == 1:
return kept_rules[0][1](ctx, *rule_args, **rule_kwargs)
assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
assert len(ctx.dim_var_values) >= 1, "Must have a platform_index variable"
# The first dim_var_values is the platform index
current_platform_idx = ctx.dim_var_values[0]
# Compute the rule index based on the current platform
i32_type = aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
if current_platform_idx.type != i32_type:
current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
rule_idx_op = hlo.CaseOp([i32_type],
index=current_platform_idx,
num_branches=len(platforms))
for i, p in enumerate(platforms):
branch = rule_idx_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
hlo.ReturnOp(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
case_op = hlo.CaseOp(util.flatten(map(aval_to_ir_types, ctx.avals_out)),
index=rule_idx_op,
num_branches=len(kept_rules))
for i, (_, rule) in enumerate(kept_rules):
inner_ctx = ctx.replace()
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
output = rule(inner_ctx, *rule_args, **rule_kwargs)
try:
out_nodes = map(wrap_singleton_ir_values, output)
except TypeError as e:
raise ValueError("Output of translation rule must be iterable: "
f"{description}, got output {output}") from e
hlo.ReturnOp(util.flatten(out_nodes))
return case_op.results
def _ir_consts(consts):
unique_consts = {id(const): const for const in consts}
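
When more than one kept rule survives, `lower_multi_platform` emits a two-level dispatch: the first `case` op maps the run-time platform index (the first dimension-variable argument) to a rule index, and the second `case` op selects among the kept rules. A plain-Python model of that dispatch for the docstring's example, with trivial stand-ins for the rules:

    platforms = ("cpu", "gpu", "tpu")
    rule0 = lambda x: x + 1.0               # stand-in for the ["tpu", "cpu"] rule
    rule1 = lambda x: x + 2.0               # stand-in for the default (None) rule
    kept_rules = [rule0, rule1]
    platform_to_rule_idx = {"cpu": 0, "gpu": 1, "tpu": 0}

    def dispatch(current_platform_idx, *args):
        # first case op: platform index -> rule index
        rule_idx = platform_to_rule_idx[platforms[current_platform_idx]]
        # second case op: rule index -> lowered computation
        return kept_rules[rule_idx](*args)

    assert dispatch(0, 3.0) == 4.0          # cpu -> rule0
    assert dispatch(1, 3.0) == 5.0          # gpu -> rule1
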


@@ -2219,7 +2219,7 @@ def lower_mesh_computation(
closed_jaxpr,
ordered_effects=ordered_effects,
backend_or_name=backend,
platform=lowering_parameters.platforms or backend.platform,
platform=lowering_parameters.override_platform or backend.platform,
axis_context=axis_ctx,
name_stack=name_stack,
donated_args=donated_invars,
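
On the single-platform path, `lower_mesh_computation` now passes the new `override_platform` property rather than the raw `platforms` tuple, so `lower_jaxpr_to_module` always receives a plain string. A sketch of the intended fallback, under the assumption that `override_platform` yields the requested platform when `platforms` is set and `None` otherwise (only the `None` branch of that property is visible in the first hunk above):

    backend_platform = "cpu"    # stand-in for backend.platform

    # Cross-platform lowering: the requested platform wins.
    p = mlir.LoweringParameters(platforms=("tpu",)).override_platform or backend_platform  # "tpu"
    # Default lowering: fall back to the backend's own platform.
    p = mlir.LoweringParameters().override_platform or backend_platform                    # "cpu"
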


@@ -16,7 +16,7 @@ import functools
import logging
import math
import re
from typing import Optional
from typing import Optional, Sequence
import unittest
from absl.testing import absltest
@@ -38,6 +38,43 @@ import numpy as np
config.parse_flags_with_absl()
# A primitive for testing multi-platform lowering. Takes one argument and
# adds a different value to it: cpu=2., tpu=3., cuda=4., rocm=5.
_testing_multi_platform_p = core.Primitive("testing_multi_platform")
_testing_multi_platform_to_add = dict(cpu=2., tpu=3., cuda=4., rocm=5.)
@_testing_multi_platform_p.def_abstract_eval
def _testing_multi_platform_abstract_eval(xaval: core.AbstractValue):
assert xaval.dtype == np.float32 # type: ignore
return xaval
def _testing_multi_platform_lowering(ctx: mlir.LoweringRuleContext,
x: mlir.Value,
*,
platform: str) -> Sequence[mlir.Value]:
to_add = _testing_multi_platform_to_add[platform]
to_add_value = mlir.broadcast_in_dim(ctx,
mlir.ir_constant(np.float32(to_add)),
ctx.avals_in[0],
broadcast_dimensions=())
return mlir.hlo.AddOp(x, to_add_value).results
# Register a default rule for cuda, to test the default-platform rule selection.
mlir.register_lowering(_testing_multi_platform_p,
functools.partial(_testing_multi_platform_lowering,
platform="cuda"))
for platform in ["cpu", "tpu", "rocm"]:
mlir.register_lowering(_testing_multi_platform_p,
functools.partial(_testing_multi_platform_lowering,
platform=platform),
platform=platform)
def _testing_multi_platform_func(x):
return _testing_multi_platform_p.bind(x)
def _testing_multi_platform_fun_expected(x):
return x + _testing_multi_platform_to_add[xb.canonicalize_platform(jtu.device_under_test())]
class JaxExportTest(jtu.JaxTestCase):
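
The registrations above deliberately mix both lookup paths: cpu, tpu and rocm get platform-specific rules, while the module-wide default registration carries the cuda behavior, so multi-platform lowering exercises the catch-all (`None`) entry as well. A quick sanity sketch of what the primitive computes under ordinary single-platform jit, assuming a CPU host (where the added constant is 2.):

    import jax
    import numpy as np

    x = np.arange(3, dtype=np.float32)
    y = jax.jit(_testing_multi_platform_func)(x)
    np.testing.assert_allclose(y, x + 2.)   # == _testing_multi_platform_fun_expected(x) on cpu
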
@@ -568,54 +605,50 @@ class JaxExportTest(jtu.JaxTestCase):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
x = np.arange(5, dtype=np.float32)
# TODO: use a function with different behavior for different platforms
exp = export.export(jnp.sin,
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu'))
module_str = str(exp.mlir_module())
platform_index = re.findall(
r"(%arg\d):\s*tensor<i..>\s*{jax.platform_index = true}",
module_str)
self.assertEqual(["%arg0"], platform_index,
f"Found {platform_index} in {module_str}")
expected_main_re = (
r"@main\("
r"%arg0: tensor<i..> {jax.platform_index = true}.*, "
r"%arg1: tensor<5xf32>.* ->")
self.assertRegex(module_str, expected_main_re)
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x))
self.assertAllClose(res, _testing_multi_platform_fun_expected(x))
def test_multi_platform_nested(self):
if jtu.test_device_matches(["tpu"]):
# The outer export is not applicable to TPU
raise unittest.SkipTest("Not intended for running on TPU")
x = np.arange(5, dtype=np.float32)
# TODO: use a function with different behavior for different platforms
exp = export.export(jnp.sin,
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
exp = export.export(_testing_multi_platform_func,
lowering_platforms=('cpu', 'tpu', 'cuda'))(x)
self.assertEqual(exp.lowering_platforms, ('cpu', 'tpu', 'cuda'))
# Now serialize the call to the exported using a different sequence of
# lowering platforms, but included in the lowering platforms for the
# nested exported.
# TODO: improve this test once we implement true multi-platform lowering
exp2 = export.export(export.call_exported(exp),
lowering_platforms=('cpu', 'cuda'))(x)
lowering_platforms=('cpu', 'cuda'))(x)
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, np.sin(x))
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
def test_multi_platform_and_poly(self):
if jtu.test_device_matches(["gpu"]):
# The export is not applicable to GPU
raise unittest.SkipTest("Not intended for running on GPU")
# TODO: use a function with different behavior for different platforms
exp = export.export(lambda x: jnp.reshape(jnp.sin(x), (-1,)),
lowering_platforms=('cpu', 'tpu'))(
exp = export.export(lambda x: jnp.reshape(_testing_multi_platform_func(x), (-1,)),
lowering_platforms=('cpu', 'tpu'))(
export.poly_spec((5, 6), np.float32, "b1, b2")
)
x = np.arange(12, dtype=np.float32).reshape((3, 4))
res = export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x).reshape((-1,)))
self.assertAllClose(res, _testing_multi_platform_fun_expected(x).reshape((-1,)))
# Now serialize the call to the exported
exp2 = export.export(export.call_exported(exp))(x)
res2 = export.call_exported(exp2)(x)
self.assertAllClose(res2, np.sin(x).reshape((-1,)))
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x).reshape((-1,)))
if __name__ == "__main__":