diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index bf09b3088..ba429dfe4 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -585,6 +585,8 @@ class ModuleContext:
   ip: ir.InsertionPoint
   symbol_table: ir.SymbolTable
   backend_or_name: str | xb.XlaBackend | None
+  # The lowering platforms for the module. Can be more than one only when
+  # exporting.
   platforms: Sequence[str]
   axis_context: AxisContext
   keepalives: list[Any]
@@ -689,6 +691,9 @@ class LoweringRuleContext:
   # module_context.shape_poly_state.dim_vars
   dim_var_values: Sequence[ir.Value] = ()
   compute_type: str | None = None
+  # Overrides module_context.platforms if not None. Used during multi-platform
+  # lowering, when in a scope with a subset of module_context.platforms.
+  platforms: Sequence[str] | None = None

   def set_tokens_out(self, tokens_out: TokenSet):
     assert self.tokens_out is None, 'Should only set `tokens_out` once.'
@@ -1662,7 +1667,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
     rule_args: the args of the lowering rules.
     rule_kwargs: the kwargs of the lowering rules.
   """
-  platforms: Sequence[str] = ctx.module_context.platforms
+  platforms: Sequence[str] = ctx.platforms or ctx.module_context.platforms
   # Special case the common case (single-platform lowering)
   if len(platforms) == 1:
     rule = platform_rules.get(platforms[0], default_rule)
@@ -1723,7 +1728,10 @@
                        index=rule_idx_op,
                        num_branches=len(kept_rules))
   for i, rule in enumerate(kept_rules):
-    inner_ctx = ctx.replace()
+    platforms_for_this_rule = [p
+                               for p, rule_idx in platform_to_kept_rules_idx.items()
+                               if rule_idx == i]
+    inner_ctx = ctx.replace(platforms=platforms_for_this_rule)
     branch = case_op.regions[i].blocks.append()
     with ir.InsertionPoint(branch):
       output = rule(inner_ctx, *rule_args, **rule_kwargs)
@@ -1764,7 +1772,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:

   The returned function does not use `avals_out`, so callers may pass any
   value as `avals_out`."""
-  def f_lowered(ctx, *args, **params):
+  def f_lowered(ctx: LoweringRuleContext, *args, **params):
     f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
     wrapped_fun = lu.wrap_init(f, params)

@@ -1774,11 +1782,12 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
       # case, we need to form a jaxpr with leading binders for those axis size
       # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
       # and we need to call jaxpr_subcomp with these arguments made explicit.
+      assert ctx.axis_size_env is not None
       args = (*ctx.axis_size_env.values(), *args)
       idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
       i32_aval = core.ShapedArray((), np.dtype('int32'))
       implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
-      explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))
+      explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))  # type: ignore
                         if type(a) is core.DShapedArray else a, True)
                        for a in ctx.avals_in]
       wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
@@ -1787,8 +1796,12 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
       jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
       # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?

+    if ctx.platforms is not None:
+      sub_context = ctx.module_context.replace(platforms=ctx.platforms)
+    else:
+      sub_context = ctx.module_context
     out, tokens = jaxpr_subcomp(
-        ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
+        sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
         _ir_consts(consts), *map(wrap_singleton_ir_values, args),
         dim_var_values=ctx.dim_var_values)
     ctx.set_tokens_out(tokens)
diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py
index 50cfff052..c95a146f2 100644
--- a/jax/_src/pallas/mosaic/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic/pallas_call_registration.py
@@ -96,7 +96,7 @@ def pallas_call_tpu_lowering_rule(
   return mosaic.as_tpu_kernel(
       mosaic_module,
       out_avals,
-      backend=ctx.module_context.backend,
+      backend="tpu",
       kernel_name=name,
       cost_estimate=mosaic_params.get("cost_estimate"),
       vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index a4436cd1d..d18e31913 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -720,38 +720,48 @@ def _pallas_call_lowering(
     impl = partial(_pallas_call_impl, **params, interpret=True)
     return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)

-  try:
-    [platform] = ctx.module_context.platforms
-  except ValueError:
-    raise ValueError(
-        "Can only lower pallas_call on a single platform."
-    ) from None
-
-  if platform == "cpu":
+  def cpu_lowering(ctx: mlir.LoweringRuleContext,
+                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
+                   **params):
     raise ValueError("Only interpret mode is supported on CPU backend.")
-  elif platform == "cuda" or platform == "rocm":
+
+  def tpu_lowering(ctx: mlir.LoweringRuleContext,
+                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
+                   **params):
+    try:
+      from jax._src.pallas.mosaic import pallas_call_registration
+    except ImportError:
+      raise _unsupported_lowering_error("tpu")
+    else:
+      return pallas_call_registration.pallas_call_tpu_lowering_rule(
+          ctx, *in_nodes, **params
+      )
+
+  def gpu_lowering(ctx: mlir.LoweringRuleContext,
+                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
+                   **params):
     try:
       if _PALLAS_USE_MOSAIC_GPU.value:
         from jax._src.pallas.mosaic_gpu import pallas_call_registration
       else:
         from jax._src.pallas.triton import pallas_call_registration  # type: ignore
     except ImportError:
-      pass
+      raise _unsupported_lowering_error("gpu")
     else:
       return pallas_call_registration.pallas_call_lowering(
-          ctx, *in_nodes, interpret=interpret, **params
-      )
-  elif platform == "tpu":
-    try:
-      from jax._src.pallas.mosaic import pallas_call_registration  # type: ignore
-    except ImportError:
-      pass
-    else:
-      return pallas_call_registration.pallas_call_tpu_lowering_rule(
-          ctx, *in_nodes, interpret=interpret, **params
+          ctx, *in_nodes, **params
       )

-  raise _unsupported_lowering_error(platform)
+  return mlir.lower_per_platform(ctx, "pallas_call",
+                                 dict(cpu=cpu_lowering,
+                                      tpu=tpu_lowering,
+                                      cuda=gpu_lowering,
+                                      rocm=gpu_lowering),
+                                 None,  # default_rule
+                                 effects.no_effects,
+                                 *in_nodes,
+                                 interpret=interpret,
+                                 **params)

 mlir.register_lowering(pallas_call_p, _pallas_call_lowering)
diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py
index 28d59af02..2158167f7 100644
--- a/jax/_src/pallas/triton/pallas_call_registration.py
+++ b/jax/_src/pallas/triton/pallas_call_registration.py
@@ -76,9 +76,8 @@ def pallas_call_lowering(
   )
   triton_params = compiler_params.get("triton", compiler_params)
compiler_params.get("triton", compiler_params) num_warps = triton_params.pop("num_warps", 4) - if len(ctx.module_context.platforms) > 1: - raise NotImplementedError("multi-platform lowering for Pallas kernels") - if ctx.module_context.platforms[0] == "rocm": + [lowering_platform] = ctx.platforms or ctx.module_context.platforms + if lowering_platform == "rocm": num_stages = triton_params.pop("num_stages", 1) else: num_stages = triton_params.pop("num_stages", 3) diff --git a/tests/export_test.py b/tests/export_test.py index ef55e1a62..cb2d7bdc4 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import annotations +from collections.abc import Sequence import contextlib import dataclasses import functools @@ -1369,6 +1370,64 @@ class JaxExportTest(jtu.JaxTestCase): res2 = exp2.call(x) self.assertAllClose(res2, _testing_multi_platform_fun_expected(x)) + def test_multi_platform_mlir_lower_fun_with_platform_specific_primitives(self): + # A primitive with multiple lowering rules, which themselves involve + # tracing primitives with per-platform rules, using mlir.lower_fun. + # This situation arises for Pallas lowering. + def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, + x: mlir.ir.Value) -> Sequence[mlir.ir.Value]: + # Lowering n * x + res = x + for i in range(n - 1): + res = mlir.hlo.AddOp(res, x) + return res.results + + times_2 = core.Primitive("__testing_times_2") # x2 for cpu + times_2.def_abstract_eval(lambda x: x) + # Define lowering rules only for the relevant platforms, ensure there + # is no error about missing lowering rules + mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2), + "cpu") + + times_3 = core.Primitive("__testing_times_3") # x3 for cuda + times_3.def_abstract_eval(lambda x: x) + mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3), + "cuda") + + times_4 = core.Primitive("__testing_times_4") # x4 for tpu + times_4.def_abstract_eval(lambda x: x) + mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4), + "tpu") + + times_2_or_3 = core.Primitive("__testing_times_2_or_3") # x2 for cpu, x3 for cuda + times_2_or_3.def_abstract_eval(lambda x: x) + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_2.bind, + multiple_results=False), "cpu") + mlir.register_lowering(times_2_or_3, + mlir.lower_fun(times_3.bind, + multiple_results=False), "cuda") + + times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4") # x2 for cpu, x3 for cuda, x4 for tpu + times_2_or_3_or_4.def_abstract_eval(lambda x: x) + times_2_or_3_or_4_lowering_cpu_cuda = mlir.lower_fun(times_2_or_3.bind, + multiple_results=False) + for platform in ["cpu", "cuda"]: + mlir.register_lowering(times_2_or_3_or_4, + times_2_or_3_or_4_lowering_cpu_cuda, + platform) + mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind, + multiple_results=False), + "tpu") + + @jax.jit + def f(x): + return times_2_or_3_or_4.bind(x) + x = np.float32(42.) 
+    exp = export.export(f, lowering_platforms=["cpu", "cuda", "tpu"])(x)
+    expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()])
+    self.assertAllClose(exp.call(x), expected)
+
   def test_multi_platform_and_poly(self):
     if jtu.test_device_matches(["gpu"]):
       # The export is not applicable to GPU
diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py
index 1c46a226e..c572357c6 100644
--- a/tests/pallas/export_pallas_test.py
+++ b/tests/pallas/export_pallas_test.py
@@ -43,11 +43,12 @@ class ExportTest(jtu.JaxTestCase):
     a = np.arange(8)
     exp = export.export(
         add_vectors,
-        # TODO(necula): Make this test work on GPU also
-        lowering_platforms=["tpu"],
+        lowering_platforms=["tpu", "cuda"],
     )(a, a)

-    if jtu.device_under_test() == "tpu":
+    if (jtu.device_under_test() == "tpu" or
+        (jtu.device_under_test() == "gpu" and
+         jtu.is_cuda_compute_capability_at_least("8.0"))):
       res = export.call(exp)(a, a)
       self.assertAllClose(res, a + a)