mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas] Add support for cross-platform lowering
When implementing this I have discovered that the multi-platform lowering support does not handle the case when the lowering rule for a platform invoke tracing (via `mlir.lower_fun`) and that tracing encounters a primitive that has lowering rules only for a particular platform. To support this, I have added the `LoweringRuleContext.platforms` to override `ModuleContext.platforms` with a potentially narrower set of lowering platforms. Added a test for this scenario.
This commit is contained in:
parent
9b68873436
commit
97db0e758d
@ -585,6 +585,8 @@ class ModuleContext:
|
||||
ip: ir.InsertionPoint
|
||||
symbol_table: ir.SymbolTable
|
||||
backend_or_name: str | xb.XlaBackend | None
|
||||
# The lowering platforms for the module. Can be more than one only when
|
||||
# exporting.
|
||||
platforms: Sequence[str]
|
||||
axis_context: AxisContext
|
||||
keepalives: list[Any]
|
||||
@ -689,6 +691,9 @@ class LoweringRuleContext:
|
||||
# module_context.shape_poly_state.dim_vars
|
||||
dim_var_values: Sequence[ir.Value] = ()
|
||||
compute_type: str | None = None
|
||||
# Override module_context.platforms if not None. Used during multi-platform
|
||||
# lowering, when in a scope with a subset of the module_context.platforms.
|
||||
platforms: Sequence[str] | None = None
|
||||
|
||||
def set_tokens_out(self, tokens_out: TokenSet):
|
||||
assert self.tokens_out is None, 'Should only set `tokens_out` once.'
|
||||
@ -1662,7 +1667,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
|
||||
rule_args: the args of the lowering rules.
|
||||
rule_kwargs: the kwargs of the lowering rules.
|
||||
"""
|
||||
platforms: Sequence[str] = ctx.module_context.platforms
|
||||
platforms: Sequence[str] = ctx.platforms or ctx.module_context.platforms
|
||||
# Special case the common case (single-platform lowering)
|
||||
if len(platforms) == 1:
|
||||
rule = platform_rules.get(platforms[0], default_rule)
|
||||
@ -1723,7 +1728,10 @@ def lower_per_platform(ctx: LoweringRuleContext,
|
||||
index=rule_idx_op,
|
||||
num_branches=len(kept_rules))
|
||||
for i, rule in enumerate(kept_rules):
|
||||
inner_ctx = ctx.replace()
|
||||
platforms_for_this_rule = [p
|
||||
for p, rule_idx in platform_to_kept_rules_idx.items()
|
||||
if rule_idx == i]
|
||||
inner_ctx = ctx.replace(platforms=platforms_for_this_rule)
|
||||
branch = case_op.regions[i].blocks.append()
|
||||
with ir.InsertionPoint(branch):
|
||||
output = rule(inner_ctx, *rule_args, **rule_kwargs)
|
||||
@ -1764,7 +1772,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
||||
|
||||
The returned function does not use `avals_out`, so callers may pass any value
|
||||
as `avals_out`."""
|
||||
def f_lowered(ctx, *args, **params):
|
||||
def f_lowered(ctx: LoweringRuleContext, *args, **params):
|
||||
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
|
||||
wrapped_fun = lu.wrap_init(f, params)
|
||||
|
||||
@ -1774,11 +1782,12 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
||||
# case, we need to form a jaxpr with leading binders for those axis size
|
||||
# arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
|
||||
# and we need to call jaxpr_subcomp with these arguments made explicit.
|
||||
assert ctx.axis_size_env is not None
|
||||
args = (*ctx.axis_size_env.values(), *args)
|
||||
idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
|
||||
i32_aval = core.ShapedArray((), np.dtype('int32'))
|
||||
implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
|
||||
explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))
|
||||
explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape)) # type: ignore
|
||||
if type(a) is core.DShapedArray else a, True)
|
||||
for a in ctx.avals_in]
|
||||
wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
|
||||
@ -1787,8 +1796,12 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
|
||||
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
|
||||
# TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?
|
||||
|
||||
if ctx.platforms is not None:
|
||||
sub_context = ctx.module_context.replace(platforms=ctx.platforms)
|
||||
else:
|
||||
sub_context = ctx.module_context
|
||||
out, tokens = jaxpr_subcomp(
|
||||
ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
|
||||
sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
|
||||
_ir_consts(consts), *map(wrap_singleton_ir_values, args),
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
ctx.set_tokens_out(tokens)
|
||||
|
@ -96,7 +96,7 @@ def pallas_call_tpu_lowering_rule(
|
||||
return mosaic.as_tpu_kernel(
|
||||
mosaic_module,
|
||||
out_avals,
|
||||
backend=ctx.module_context.backend,
|
||||
backend="tpu",
|
||||
kernel_name=name,
|
||||
cost_estimate=mosaic_params.get("cost_estimate"),
|
||||
vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
|
||||
|
@ -720,38 +720,48 @@ def _pallas_call_lowering(
|
||||
impl = partial(_pallas_call_impl, **params, interpret=True)
|
||||
return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)
|
||||
|
||||
try:
|
||||
[platform] = ctx.module_context.platforms
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Can only lower pallas_call on a single platform."
|
||||
) from None
|
||||
|
||||
if platform == "cpu":
|
||||
def cpu_lowering(ctx: mlir.LoweringRuleContext,
|
||||
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
|
||||
**params):
|
||||
raise ValueError("Only interpret mode is supported on CPU backend.")
|
||||
elif platform == "cuda" or platform == "rocm":
|
||||
|
||||
def tpu_lowering(ctx: mlir.LoweringRuleContext,
|
||||
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
|
||||
**params):
|
||||
try:
|
||||
from jax._src.pallas.mosaic import pallas_call_registration
|
||||
except ImportError:
|
||||
raise _unsupported_lowering_error("tpu")
|
||||
else:
|
||||
return pallas_call_registration.pallas_call_tpu_lowering_rule(
|
||||
ctx, *in_nodes, **params
|
||||
)
|
||||
|
||||
def gpu_lowering(ctx: mlir.LoweringRuleContext,
|
||||
*in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
|
||||
**params):
|
||||
try:
|
||||
if _PALLAS_USE_MOSAIC_GPU.value:
|
||||
from jax._src.pallas.mosaic_gpu import pallas_call_registration
|
||||
else:
|
||||
from jax._src.pallas.triton import pallas_call_registration # type: ignore
|
||||
except ImportError:
|
||||
pass
|
||||
raise _unsupported_lowering_error("gpu")
|
||||
else:
|
||||
return pallas_call_registration.pallas_call_lowering(
|
||||
ctx, *in_nodes, interpret=interpret, **params
|
||||
)
|
||||
elif platform == "tpu":
|
||||
try:
|
||||
from jax._src.pallas.mosaic import pallas_call_registration # type: ignore
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
return pallas_call_registration.pallas_call_tpu_lowering_rule(
|
||||
ctx, *in_nodes, interpret=interpret, **params
|
||||
ctx, *in_nodes, **params
|
||||
)
|
||||
|
||||
raise _unsupported_lowering_error(platform)
|
||||
return mlir.lower_per_platform(ctx, "pallas_call",
|
||||
dict(cpu=cpu_lowering,
|
||||
tpu=tpu_lowering,
|
||||
cuda=gpu_lowering,
|
||||
rocm=gpu_lowering),
|
||||
None, # default_rule
|
||||
effects.no_effects,
|
||||
*in_nodes,
|
||||
interpret=interpret,
|
||||
**params)
|
||||
|
||||
|
||||
mlir.register_lowering(pallas_call_p, _pallas_call_lowering)
|
||||
|
@ -76,9 +76,8 @@ def pallas_call_lowering(
|
||||
)
|
||||
triton_params = compiler_params.get("triton", compiler_params)
|
||||
num_warps = triton_params.pop("num_warps", 4)
|
||||
if len(ctx.module_context.platforms) > 1:
|
||||
raise NotImplementedError("multi-platform lowering for Pallas kernels")
|
||||
if ctx.module_context.platforms[0] == "rocm":
|
||||
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
|
||||
if lowering_platform == "rocm":
|
||||
num_stages = triton_params.pop("num_stages", 1)
|
||||
else:
|
||||
num_stages = triton_params.pop("num_stages", 3)
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
@ -1369,6 +1370,64 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
res2 = exp2.call(x)
|
||||
self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))
|
||||
|
||||
def test_multi_platform_mlir_lower_fun_with_platform_specific_primitives(self):
|
||||
# A primitive with multiple lowering rules, which themselves involve
|
||||
# tracing primitives with per-platform rules, using mlir.lower_fun.
|
||||
# This situation arises for Pallas lowering.
|
||||
def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext,
|
||||
x: mlir.ir.Value) -> Sequence[mlir.ir.Value]:
|
||||
# Lowering n * x
|
||||
res = x
|
||||
for i in range(n - 1):
|
||||
res = mlir.hlo.AddOp(res, x)
|
||||
return res.results
|
||||
|
||||
times_2 = core.Primitive("__testing_times_2") # x2 for cpu
|
||||
times_2.def_abstract_eval(lambda x: x)
|
||||
# Define lowering rules only for the relevant platforms, ensure there
|
||||
# is no error about missing lowering rules
|
||||
mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2),
|
||||
"cpu")
|
||||
|
||||
times_3 = core.Primitive("__testing_times_3") # x3 for cuda
|
||||
times_3.def_abstract_eval(lambda x: x)
|
||||
mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3),
|
||||
"cuda")
|
||||
|
||||
times_4 = core.Primitive("__testing_times_4") # x4 for tpu
|
||||
times_4.def_abstract_eval(lambda x: x)
|
||||
mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4),
|
||||
"tpu")
|
||||
|
||||
times_2_or_3 = core.Primitive("__testing_times_2_or_3") # x2 for cpu, x3 for cuda
|
||||
times_2_or_3.def_abstract_eval(lambda x: x)
|
||||
mlir.register_lowering(times_2_or_3,
|
||||
mlir.lower_fun(times_2.bind,
|
||||
multiple_results=False), "cpu")
|
||||
mlir.register_lowering(times_2_or_3,
|
||||
mlir.lower_fun(times_3.bind,
|
||||
multiple_results=False), "cuda")
|
||||
|
||||
times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4") # x2 for cpu, x3 for cuda, x4 for tpu
|
||||
times_2_or_3_or_4.def_abstract_eval(lambda x: x)
|
||||
times_2_or_3_or_4_lowering_cpu_cuda = mlir.lower_fun(times_2_or_3.bind,
|
||||
multiple_results=False)
|
||||
for platform in ["cpu", "cuda"]:
|
||||
mlir.register_lowering(times_2_or_3_or_4,
|
||||
times_2_or_3_or_4_lowering_cpu_cuda,
|
||||
platform)
|
||||
mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind,
|
||||
multiple_results=False),
|
||||
"tpu")
|
||||
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return times_2_or_3_or_4.bind(x)
|
||||
x = np.float32(42.)
|
||||
exp = export.export(f, lowering_platforms=["cpu", "cuda", "tpu"])(x)
|
||||
expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()])
|
||||
self.assertAllClose(exp.call(x), expected)
|
||||
|
||||
def test_multi_platform_and_poly(self):
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
# The export is not applicable to GPU
|
||||
|
@ -43,11 +43,12 @@ class ExportTest(jtu.JaxTestCase):
|
||||
a = np.arange(8)
|
||||
exp = export.export(
|
||||
add_vectors,
|
||||
# TODO(necula): Make this test work on GPU also
|
||||
lowering_platforms=["tpu"],
|
||||
lowering_platforms=["tpu", "cuda"],
|
||||
)(a, a)
|
||||
|
||||
if jtu.device_under_test() == "tpu":
|
||||
if (jtu.device_under_test() == "tpu" or
|
||||
(jtu.device_under_test() == "gpu" and
|
||||
jtu.is_cuda_compute_capability_at_least("8.0"))):
|
||||
res = export.call(exp)(a, a)
|
||||
self.assertAllClose(res, a + a)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user