Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[pallas] Add support for cross-platform lowering
While implementing this, I discovered that the multi-platform lowering support does not handle the case when the lowering rule for one platform invokes tracing (via `mlir.lower_fun`) and that tracing encounters a primitive that has lowering rules only for particular platforms. To support this, I added `LoweringRuleContext.platforms`, which overrides `ModuleContext.platforms` with a potentially narrower set of lowering platforms. Added a test for this scenario.
Commit 97db0e758d (parent 9b68873436)
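To make the scenario concrete, here is a minimal sketch of the pattern this commit enables, modeled on the new test added in this change. It uses JAX-internal APIs (`jax._src.core`, `jax._src.interpreters.mlir`) the same way the test does; the primitive names, the multiplier, and the sample input are illustrative assumptions, not part of the commit.

# Illustrative sketch (assumed names); mirrors the registration pattern of the
# new test in this commit and relies on JAX-internal APIs.
import functools

import numpy as np
import jax
from jax._src import core
from jax._src.interpreters import mlir


def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext, x: mlir.ir.Value):
  # Lower n * x as a chain of n - 1 additions.
  res = x
  for _ in range(n - 1):
    res = mlir.hlo.AddOp(res, x)
  return res.results


# A primitive that has a lowering rule only for CPU.
times_2 = core.Primitive("__example_times_2")
times_2.def_abstract_eval(lambda x: x)
mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2), "cpu")

# A primitive whose CPU lowering traces times_2.bind through mlir.lower_fun.
# When a module is lowered for several platforms at once, the traced body must
# be lowered only for the platform of the enclosing per-platform branch;
# LoweringRuleContext.platforms carries that narrower set into lower_fun.
wrapped = core.Primitive("__example_wrapped")
wrapped.def_abstract_eval(lambda x: x)
mlir.register_lowering(
    wrapped, mlir.lower_fun(times_2.bind, multiple_results=False), "cpu")


@jax.jit
def f(x):
  return wrapped.bind(x)


print(f(np.float32(21.)))  # 42.0 when executed on a CPU backend

Before this change, lowering `wrapped` as part of a multi-platform module could fail inside `mlir.lower_fun`, because the traced `times_2` was looked up against all module platforms rather than only the platform of the branch being lowered.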
@@ -585,6 +585,8 @@ class ModuleContext:
   ip: ir.InsertionPoint
   symbol_table: ir.SymbolTable
   backend_or_name: str | xb.XlaBackend | None
+  # The lowering platforms for the module. Can be more than one only when
+  # exporting.
   platforms: Sequence[str]
   axis_context: AxisContext
   keepalives: list[Any]
@@ -689,6 +691,9 @@ class LoweringRuleContext:
   # module_context.shape_poly_state.dim_vars
   dim_var_values: Sequence[ir.Value] = ()
   compute_type: str | None = None
+  # Override module_context.platforms if not None. Used during multi-platform
+  # lowering, when in a scope with a subset of the module_context.platforms.
+  platforms: Sequence[str] | None = None

   def set_tokens_out(self, tokens_out: TokenSet):
     assert self.tokens_out is None, 'Should only set `tokens_out` once.'
@@ -1662,7 +1667,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
     rule_args: the args of the lowering rules.
     rule_kwargs: the kwargs of the lowering rules.
   """
-  platforms: Sequence[str] = ctx.module_context.platforms
+  platforms: Sequence[str] = ctx.platforms or ctx.module_context.platforms
   # Special case the common case (single-platform lowering)
   if len(platforms) == 1:
     rule = platform_rules.get(platforms[0], default_rule)
@@ -1723,7 +1728,10 @@ def lower_per_platform(ctx: LoweringRuleContext,
                        index=rule_idx_op,
                        num_branches=len(kept_rules))
   for i, rule in enumerate(kept_rules):
-    inner_ctx = ctx.replace()
+    platforms_for_this_rule = [p
+                               for p, rule_idx in platform_to_kept_rules_idx.items()
+                               if rule_idx == i]
+    inner_ctx = ctx.replace(platforms=platforms_for_this_rule)
     branch = case_op.regions[i].blocks.append()
     with ir.InsertionPoint(branch):
       output = rule(inner_ctx, *rule_args, **rule_kwargs)
@@ -1764,7 +1772,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:

   The returned function does not use `avals_out`, so callers may pass any value
   as `avals_out`."""
-  def f_lowered(ctx, *args, **params):
+  def f_lowered(ctx: LoweringRuleContext, *args, **params):
     f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
     wrapped_fun = lu.wrap_init(f, params)

@@ -1774,11 +1782,12 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
       # case, we need to form a jaxpr with leading binders for those axis size
       # arguments (by computing an InputType and using trace_to_jaxpr_dynamic2),
       # and we need to call jaxpr_subcomp with these arguments made explicit.
+      assert ctx.axis_size_env is not None
       args = (*ctx.axis_size_env.values(), *args)
       idx = {d: core.DBIdx(i) for i, d in enumerate(ctx.axis_size_env)}
       i32_aval = core.ShapedArray((), np.dtype('int32'))
       implicit_args = [(i32_aval, False)] * len(ctx.axis_size_env)
-      explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))
+      explicit_args = [(a.update(shape=tuple(idx.get(d, d) for d in a.shape))  # type: ignore
                         if type(a) is core.DShapedArray else a, True)
                        for a in ctx.avals_in]
       wrapped_fun = lu.annotate(wrapped_fun, (*implicit_args, *explicit_args))
@@ -1787,8 +1796,12 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
       jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in)
     # TODO(frostig,mattjj): check ctx.avals_out against jaxpr avals out?

+    if ctx.platforms is not None:
+      sub_context = ctx.module_context.replace(platforms=ctx.platforms)
+    else:
+      sub_context = ctx.module_context
     out, tokens = jaxpr_subcomp(
-        ctx.module_context, jaxpr, ctx.name_stack, ctx.tokens_in,
+        sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
         _ir_consts(consts), *map(wrap_singleton_ir_values, args),
         dim_var_values=ctx.dim_var_values)
     ctx.set_tokens_out(tokens)
@@ -96,7 +96,7 @@ def pallas_call_tpu_lowering_rule(
   return mosaic.as_tpu_kernel(
       mosaic_module,
       out_avals,
-      backend=ctx.module_context.backend,
+      backend="tpu",
       kernel_name=name,
       cost_estimate=mosaic_params.get("cost_estimate"),
       vmem_limit_bytes=mosaic_params.get("vmem_limit_bytes"),
@@ -720,38 +720,48 @@ def _pallas_call_lowering(
     impl = partial(_pallas_call_impl, **params, interpret=True)
     return mlir.lower_fun(impl, multiple_results=True)(ctx, *in_nodes)

-  try:
-    [platform] = ctx.module_context.platforms
-  except ValueError:
-    raise ValueError(
-        "Can only lower pallas_call on a single platform."
-    ) from None
-
-  if platform == "cpu":
+  def cpu_lowering(ctx: mlir.LoweringRuleContext,
+                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
+                   **params):
     raise ValueError("Only interpret mode is supported on CPU backend.")
-  elif platform == "cuda" or platform == "rocm":
+
+  def tpu_lowering(ctx: mlir.LoweringRuleContext,
+                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
+                   **params):
+    try:
+      from jax._src.pallas.mosaic import pallas_call_registration
+    except ImportError:
+      raise _unsupported_lowering_error("tpu")
+    else:
+      return pallas_call_registration.pallas_call_tpu_lowering_rule(
+          ctx, *in_nodes, **params
+      )
+
+  def gpu_lowering(ctx: mlir.LoweringRuleContext,
+                   *in_nodes: mlir.ir.Value | Sequence[mlir.ir.Value],
+                   **params):
     try:
       if _PALLAS_USE_MOSAIC_GPU.value:
         from jax._src.pallas.mosaic_gpu import pallas_call_registration
       else:
         from jax._src.pallas.triton import pallas_call_registration  # type: ignore
     except ImportError:
-      pass
+      raise _unsupported_lowering_error("gpu")
     else:
       return pallas_call_registration.pallas_call_lowering(
-          ctx, *in_nodes, interpret=interpret, **params
+          ctx, *in_nodes, **params
       )
-  elif platform == "tpu":
-    try:
-      from jax._src.pallas.mosaic import pallas_call_registration  # type: ignore
-    except ImportError:
-      pass
-    else:
-      return pallas_call_registration.pallas_call_tpu_lowering_rule(
-          ctx, *in_nodes, interpret=interpret, **params
-      )

-  raise _unsupported_lowering_error(platform)
+  return mlir.lower_per_platform(ctx, "pallas_call",
+                                 dict(cpu=cpu_lowering,
+                                      tpu=tpu_lowering,
+                                      cuda=gpu_lowering,
+                                      rocm=gpu_lowering),
+                                 None,  # default_rule
+                                 effects.no_effects,
+                                 *in_nodes,
+                                 interpret=interpret,
+                                 **params)


 mlir.register_lowering(pallas_call_p, _pallas_call_lowering)
@@ -76,9 +76,8 @@ def pallas_call_lowering(
     )
   triton_params = compiler_params.get("triton", compiler_params)
   num_warps = triton_params.pop("num_warps", 4)
-  if len(ctx.module_context.platforms) > 1:
-    raise NotImplementedError("multi-platform lowering for Pallas kernels")
-  if ctx.module_context.platforms[0] == "rocm":
+  [lowering_platform] = ctx.platforms or ctx.module_context.platforms
+  if lowering_platform == "rocm":
     num_stages = triton_params.pop("num_stages", 1)
   else:
     num_stages = triton_params.pop("num_stages", 3)
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import annotations

+from collections.abc import Sequence
 import contextlib
 import dataclasses
 import functools
@@ -1369,6 +1370,64 @@ class JaxExportTest(jtu.JaxTestCase):
     res2 = exp2.call(x)
     self.assertAllClose(res2, _testing_multi_platform_fun_expected(x))

+  def test_multi_platform_mlir_lower_fun_with_platform_specific_primitives(self):
+    # A primitive with multiple lowering rules, which themselves involve
+    # tracing primitives with per-platform rules, using mlir.lower_fun.
+    # This situation arises for Pallas lowering.
+    def times_n_lowering(n: int, ctx: mlir.LoweringRuleContext,
+                         x: mlir.ir.Value) -> Sequence[mlir.ir.Value]:
+      # Lowering n * x
+      res = x
+      for i in range(n - 1):
+        res = mlir.hlo.AddOp(res, x)
+      return res.results
+
+    times_2 = core.Primitive("__testing_times_2")  # x2 for cpu
+    times_2.def_abstract_eval(lambda x: x)
+    # Define lowering rules only for the relevant platforms, ensure there
+    # is no error about missing lowering rules
+    mlir.register_lowering(times_2, functools.partial(times_n_lowering, 2),
+                           "cpu")
+
+    times_3 = core.Primitive("__testing_times_3")  # x3 for cuda
+    times_3.def_abstract_eval(lambda x: x)
+    mlir.register_lowering(times_3, functools.partial(times_n_lowering, 3),
+                           "cuda")
+
+    times_4 = core.Primitive("__testing_times_4")  # x4 for tpu
+    times_4.def_abstract_eval(lambda x: x)
+    mlir.register_lowering(times_4, functools.partial(times_n_lowering, 4),
+                           "tpu")
+
+    times_2_or_3 = core.Primitive("__testing_times_2_or_3")  # x2 for cpu, x3 for cuda
+    times_2_or_3.def_abstract_eval(lambda x: x)
+    mlir.register_lowering(times_2_or_3,
+                           mlir.lower_fun(times_2.bind,
+                                          multiple_results=False), "cpu")
+    mlir.register_lowering(times_2_or_3,
+                           mlir.lower_fun(times_3.bind,
+                                          multiple_results=False), "cuda")
+
+    times_2_or_3_or_4 = core.Primitive("__testing_times_2_or_3_or_4")  # x2 for cpu, x3 for cuda, x4 for tpu
+    times_2_or_3_or_4.def_abstract_eval(lambda x: x)
+    times_2_or_3_or_4_lowering_cpu_cuda = mlir.lower_fun(times_2_or_3.bind,
+                                                         multiple_results=False)
+    for platform in ["cpu", "cuda"]:
+      mlir.register_lowering(times_2_or_3_or_4,
+                             times_2_or_3_or_4_lowering_cpu_cuda,
+                             platform)
+    mlir.register_lowering(times_2_or_3_or_4, mlir.lower_fun(times_4.bind,
+                                                             multiple_results=False),
+                           "tpu")
+
+    @jax.jit
+    def f(x):
+      return times_2_or_3_or_4.bind(x)
+
+    x = np.float32(42.)
+    exp = export.export(f, lowering_platforms=["cpu", "cuda", "tpu"])(x)
+    expected = x * np.float32(dict(cpu=2, gpu=3, tpu=4)[jtu.device_under_test()])
+    self.assertAllClose(exp.call(x), expected)
+
   def test_multi_platform_and_poly(self):
     if jtu.test_device_matches(["gpu"]):
       # The export is not applicable to GPU
@@ -43,11 +43,12 @@ class ExportTest(jtu.JaxTestCase):
     a = np.arange(8)
     exp = export.export(
         add_vectors,
-        # TODO(necula): Make this test work on GPU also
-        lowering_platforms=["tpu"],
+        lowering_platforms=["tpu", "cuda"],
     )(a, a)

-    if jtu.device_under_test() == "tpu":
+    if (jtu.device_under_test() == "tpu" or
+        (jtu.device_under_test() == "gpu" and
+         jtu.is_cuda_compute_capability_at_least("8.0"))):
       res = export.call(exp)(a, a)
       self.assertAllClose(res, a + a)