[Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes

The Pallas-level pipelining generates a number of ops we haven't had to deal with before, such as conditionals and scans.
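
For illustration, a minimal sketch (not part of this commit, and requiring a recent GPU to run) of the kind of kernel the warpgroup lowering now has to handle, with a conditional and a loop of the sort emit_pipeline produces:

import functools
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
def kernel(x_ref, o_ref):
  def body(i, acc):  # the loop lowers through the scan_p rule
    return acc + jax.lax.cond(i % 2 == 0, lambda: 1, lambda: 2)  # cond_p rule
  o_ref[...] = x_ref[...] + jax.lax.fori_loop(0, 4, body, 0)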

PiperOrigin-RevId: 730899808
Author: Adam Paszke, 2025-02-25 08:39:04 -08:00, committed by jax authors
parent a3a48af105
commit 3d87a01bea
3 changed files with 120 additions and 69 deletions


@@ -692,7 +692,8 @@ def lower_jaxpr_to_mosaic_gpu(
   def read_env(atom: jax_core.Atom):
     return atom.val if isinstance(atom, jax_core.Literal) else env[atom]
 
-  def write_env(var: jax_core.Var, val):
+  def write_env(var: jax_core.Var, val, require_value: bool = True):
+    env[var] = val
     # TODO(apaszke): Handle other avals (refs, etc.).
     if isinstance(aval := var.aval, jax_core.ShapedArray):
       # TODO(apaszke): Clarify the type invariants for lane semantics?
@@ -701,7 +702,12 @@ def lower_jaxpr_to_mosaic_gpu(
         # Those with empty shapes should be represented by their scalar type.
         mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype)
         if not isinstance(val, ir.Value):
-          raise AssertionError(f"Shaped arrays must be represented by ir.Values, got: {val}")
+          if require_value:
+            raise AssertionError(f"Shaped arrays must be represented by ir.Values, got: {val}")
+          else:
+            if var.aval.shape:
+              raise AssertionError("Only scalars can be represented by non-ir.Values")
+            return  # Skip following checks.
         if aval.shape:
           if not ir.VectorType.isinstance(val.type):
             raise AssertionError(f"Non-scalar arrays must be represented by vectors, got: {val.type}")
@@ -715,10 +721,9 @@ def lower_jaxpr_to_mosaic_gpu(
             raise AssertionError(f"Scalars must be represented by non-vector types, got: {val.type}")
           if val.type != mlir_dtype:
             raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}")
-    env[var] = val
 
   map(write_env, jaxpr.constvars, consts)
-  map(write_env, jaxpr.invars, args)
+  map(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args)
   # TODO(justinfu): Handle transform scopes.
   last_local_name_stack: list[str] = []
   named_regions = []
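
The relaxed check exists because kernel invars may now arrive as plain Python scalars (e.g. indices produced by the Pallas-level pipelining). A toy restatement of the contract, with hypothetical names (ir is mlir.ir, as in the file above):

def check_written_value(val, shape, require_value):
  # Mirrors the logic above: non-ir.Value payloads are only tolerated
  # for scalar invars; everything else must already be lowered.
  if not isinstance(val, ir.Value):
    if require_value:
      raise AssertionError("Shaped arrays must be represented by ir.Values")
    if shape:
      raise AssertionError("Only scalars can be represented by non-ir.Values")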
@@ -958,9 +963,9 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree):
   if shape:
     zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
     indices = [zero_index for _ in range(len(shape))]
-  else:
-    indices = []
-  return vector_dialect.load(ty, x_smem, indices)
+    return vector_dialect.load(ty, x_smem, indices)
+  return memref_dialect.load(x_smem, [])
 
 
 @register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane)
@@ -1019,10 +1024,11 @@ def _swap_lowering_rule_wg(
   if shape:
     zero_index = arith_dialect.constant(ir.IndexType.get(), 0)
     indices = [zero_index for _ in range(len(shape))]
+    old_value = vector_dialect.load(ty, x_smem, indices)
+    vector_dialect.store(value, x_smem, indices)
   else:
-    indices = []
-  old_value = vector_dialect.load(ty, x_smem, indices)
-  vector_dialect.store(value, x_smem, indices)
+    old_value = memref_dialect.load(x_smem, [])
+    memref_dialect.store(value, x_smem, [])
   return old_value
@@ -1051,6 +1057,7 @@ def _slice_lowering_rule(
 
 @register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup)
 def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
   if len(cases) != 2:
     raise NotImplementedError(
@@ -1059,11 +1066,22 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases):
     )
   pred_aval, *cases_avals = ctx.avals_in
   [out_aval] = ctx.avals_out
-  pred = _ensure_fa(pred, pred_aval.dtype)
-  cases = _bcast(*cases, *cases_avals, out_aval)
-  # ``select`` expects the first case to be the true branch, but ``select_n``
-  # orders the cases in reverse.
-  return pred.select(*reversed(cases))
+  if ctx.thread_semantics == mgpu.ThreadSemantics.Lane:
+    pred = _ensure_fa(pred, pred_aval.dtype)
+    cases = _bcast(*cases, *cases_avals, out_aval)
+    # ``select`` expects the first case to be the true branch, but ``select_n``
+    # orders the cases in reverse.
+    return pred.select(*reversed(cases))
+  else:
+    pred = _ensure_ir_value(pred, pred_aval.dtype)
+    cases = [_ensure_ir_value(c, c_aval.dtype) for c, c_aval in zip(cases, cases_avals)]
+    # TODO(bchetioui): support implicit broadcast.
+    if any(a.shape != out_aval.shape for a in ctx.avals_in):
+      raise NotImplementedError(
+          "Implicit broadcast not implemented with warpgroup semantics")
+    # ``select`` expects the first case to be the true branch, but ``select_n``
+    # orders the cases in reverse.
+    return arith_dialect.select(pred, *reversed(cases))
 
 
 @register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane)
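
The reversal is the same in both branches because lax.select_n picks cases[pred], so the false case comes first, whereas arith.select (and FragmentedArray.select) put the true case first. A CPU-runnable sketch of that ordering:

import jax.numpy as jnp
from jax import lax

assert lax.select_n(jnp.bool_(False), jnp.int32(10), jnp.int32(20)) == 10
assert lax.select_n(jnp.bool_(True), jnp.int32(10), jnp.int32(20)) == 20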
@@ -1148,11 +1166,13 @@ def _convert_element_type_lowering_rule_wg(
   elif from_integer and to_integer:
     if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
       convert = arith_dialect.trunci
-    else:
+    elif ir.IntegerType(cur_dtype).width < ir.IntegerType(new_dtype).width:
       if mgpu_utils.is_signed(x_aval.dtype):
         convert = arith_dialect.extsi
       else:
         convert = arith_dialect.extui
+    else:
+      convert = lambda _, x: x  # signed <-> unsigned conversions
   elif from_integer and to_float:
     if mgpu_utils.is_signed(x_aval.dtype):
       convert = arith_dialect.sitofp
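
The new else branch covers same-width casts: MLIR integers are signless, so a signed <-> unsigned conversion reuses the bits unchanged and needs no arith op. A CPU-side sketch of the intended semantics:

import jax.numpy as jnp

x = jnp.int8(-1)
assert x.astype(jnp.uint8) == jnp.uint8(255)  # same bits, reinterpreted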
@@ -1229,20 +1249,17 @@ for op, si_impl, ui_impl, f_impl in [
         arith_dialect.divf,
     ),
     (lax.rem_p, arith_dialect.remsi, arith_dialect.remui, arith_dialect.remf),
-    (lax.and_p, arith_dialect.andi, arith_dialect.andi, None),
-    (lax.or_p, arith_dialect.ori, arith_dialect.ori, None),
-    (lax.xor_p, arith_dialect.xori, arith_dialect.xori, None),
     (
         lax.max_p,
         arith_dialect.maxsi,
         arith_dialect.maxui,
-        arith_dialect.maxnumf,
+        arith_dialect.maximumf,
     ),
     (
         lax.min_p,
         arith_dialect.minsi,
         arith_dialect.minui,
-        arith_dialect.minnumf,
+        arith_dialect.minimumf,
     ),
 ]:
   mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial(
@@ -1252,6 +1269,23 @@ for op, si_impl, ui_impl, f_impl in [
       f_impl=f_impl,
   )
 
+
+def _binary_boolean_op_lowering_rule_wg(
+    ctx: LoweringRuleContext, x, y, *, impl
+):
+  x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out)
+  return impl(x, y)
+
+
+for op, impl in [
+    (lax.and_p, arith_dialect.andi),
+    (lax.or_p, arith_dialect.ori),
+    (lax.xor_p, arith_dialect.xori),
+]:
+  mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial(
+      _binary_boolean_op_lowering_rule_wg,
+      impl=impl,
+  )
+
 
 CmpIPred = arith_dialect.CmpIPredicate
 CmpFPred = arith_dialect.CmpFPredicate
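
Splitting and/or/xor out of the signed/unsigned table works because bitwise ops are sign-agnostic, so a single andi/ori/xori impl serves both, and there is no float variant (the None in the old table). A CPU-side sketch:

import jax.numpy as jnp

a, b = jnp.int32(-4), jnp.int32(6)
assert a & b == (a.astype(jnp.uint32) & b.astype(jnp.uint32)).astype(jnp.int32)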
@@ -1262,7 +1296,7 @@ def _comparison_lowering_rule_wg(
   x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out)
   if jnp.issubdtype(x_aval, jnp.signedinteger):
     return arith_dialect.cmpi(si_pred, x, y)
-  elif jnp.issubdtype(x_aval, jnp.integer):
+  elif jnp.issubdtype(x_aval, jnp.integer) or jnp.issubdtype(x_aval, jnp.bool):
     return arith_dialect.cmpi(ui_pred, x, y)
   elif jnp.issubdtype(x_aval, jnp.floating):
     return arith_dialect.cmpf(f_pred, x, y)
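
Booleans lower to i1 and can safely reuse the unsigned integer predicates, since False is 0 and True is 1. A CPU-side sketch of the ordering this relies on:

import jax.numpy as jnp

assert bool(jnp.bool_(False) < jnp.bool_(True))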
@@ -1452,6 +1486,19 @@ def _debug_print_lowering_rule(
   return ()
 
 
+@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup)
+def _debug_print_lowering_rule(
+    ctx: LoweringRuleContext,
+    *args,
+    fmt,
+    has_placeholders: bool,
+):
+  del ctx, has_placeholders  # Unused.
+  if args:
+    raise NotImplementedError("debug_print only supports string messages in warpgroup semantics")
+  mgpu.debug_print(fmt)
+  return ()
+
+
 @register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane)
 @register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup)
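
Inside a warpgroup-semantics kernel only plain string messages work for now; formatted placeholders would hit the NotImplementedError above. A hypothetical kernel-body usage (not standalone code):

pl.debug_print("past the epilogue")   # OK under warpgroup semantics
# pl.debug_print("acc: {}", acc)      # placeholders: lane semantics only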
@@ -1593,13 +1640,18 @@ def _lower_jaxpr_to_for_loop(
   out_avals = ctx.avals_out[-len(arg_avals):]
   is_acc = [isinstance(v, mgpu.WGMMAAccumulator) for v in args]
 
-  def as_fas(vals, avals):
+  def as_values(vals, avals):
     if is_acc != [isinstance(v, mgpu.WGMMAAccumulator) for v in vals]:
       raise ValueError("Unexpected loop carry w.r.t. accumulators.")
-    return [v if a else _ensure_fa(v, av) for a, v, av in zip(is_acc, vals, avals)]
+    _ensure = (
+        _ensure_fa
+        if ctx.thread_semantics == mgpu.ThreadSemantics.Lane
+        else _ensure_ir_value
+    )
+    return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)]
 
-  @mgpu.fori(length, as_fas(args, arg_avals))
+  @mgpu.fori(length, as_values(args, arg_avals))
   def loop(loop_index, body_args):
     if has_loop_index:
       loop_index = arith_dialect.addi(loop_index, start)
@@ -1607,14 +1659,15 @@ def _lower_jaxpr_to_for_loop(
     else:
       jaxpr_args = [*consts, *body_args]
     outs = lower_jaxpr_to_mosaic_gpu(
-        ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args
+        ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args, thread_semantics=ctx.thread_semantics,
     )
-    return as_fas(outs, out_avals)
+    return as_values(outs, out_avals)
 
   return loop.results
 
 
 @register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup)
 def _scan_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
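
jax.lax.fori_loop and scans reach this code as scan_p; with the rule now registered for both semantics, the loop body recursively lowers under the parent's thread_semantics, and carries are normalized with _ensure_ir_value instead of _ensure_fa. A CPU-runnable sketch of the kind of scalar carry this path now handles:

import jax
import jax.numpy as jnp

def body(i, carry):
  return carry + i  # scalar carry: kept as a plain ir.Value in WG mode

assert jax.lax.fori_loop(0, 8, body, jnp.int32(0)) == 28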
@@ -1781,6 +1834,7 @@ def _while_lowering_rule(
 
 @register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane)
+@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup)
 def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   index_aval, *_arg_avals = ctx.avals_in
@@ -1800,7 +1854,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   # extract the return types
   with ir.InsertionPoint(ir.Module.create().body):
     outs = lower_jaxpr_to_mosaic_gpu(
-        ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args
+        ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args, thread_semantics=ctx.thread_semantics,
     )
     yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))]
     del outs
@@ -1822,7 +1876,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   for branch, region in zip(branches, regions):
     with ir.InsertionPoint(region.blocks.append()):
       outs = lower_jaxpr_to_mosaic_gpu(
-          ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts
+          ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts, thread_semantics=ctx.thread_semantics,
       )
       yielded_leaves, yielded_treedef = jax.tree.flatten(_yielded_values(outs, ctx.avals_out))
@@ -1917,48 +1971,41 @@ def _bcast_wg(
   if not out_aval.shape:
     return _ensure_ir_value(x, x_aval.dtype), _ensure_ir_value(y, y_aval.dtype)
   x_dtype = x_aval.dtype
-  if not isinstance(x, ir.Value) or not ir.VectorType.isinstance(x.type):
+  if not isinstance(x, ir.Value):
     if x_aval.weak_type:
       x_dtype = y_aval.dtype
-    x = _ensure_vector(x, x_dtype)
+    x = _ensure_ir_value(x, x_dtype)
   y_dtype = y_aval.dtype
-  if not isinstance(y, ir.Value) or not ir.VectorType.isinstance(y.type):
+  if not isinstance(y, ir.Value):
     if y_aval.weak_type:
       y_dtype = x_aval.dtype
-    y = _ensure_vector(y, y_dtype)
-  if x_aval.shape != out_aval.shape:
-    assert not x_aval.shape  # TODO(slebedev): Support non-scalar inputs.
+    y = _ensure_ir_value(y, y_dtype)
+  if not ir.VectorType.isinstance(x.type):
+    assert not x_aval.shape
     x = vector_dialect.splat(
         ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(x_dtype)),
         x,
     )
-  if y_aval.shape != out_aval.shape:
-    assert not y_aval.shape  # TODO(slebedev): Support non-scalar inputs.
+  elif x_aval.shape != out_aval.shape:
+    raise NotImplementedError("Unsupported broadcast")
+  if not ir.VectorType.isinstance(y.type):
+    assert not y_aval.shape
     y = vector_dialect.splat(
         ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(y_dtype)),
         y,
     )
+  elif y_aval.shape != out_aval.shape:
+    raise NotImplementedError("Unsupported broadcast")
   return x, y
 
 
-def _ensure_vector(x: object, dtype: jnp.dtype) -> ir.Value:
-  if isinstance(x, ir.Value) and ir.VectorType.isinstance(x.type):
-    assert ir.VectorType(x.type).element_type == mgpu_utils.dtype_to_ir_type(dtype)
-    return x
-  if isinstance(x, (np.number, np.ndarray, int, float)):
-    return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype))
-  raise NotImplementedError(f"Unsupported conversion to vector for: {x!r}")
-
-
 def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value:
   if isinstance(x, ir.Value):
     mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype)
-    if ir.VectorType.isinstance(x.type):
-      assert ir.VectorType(x.type).element_type == mlir_dtype
-    else:
-      assert x.type == mlir_dtype
+    assert x.type == mlir_dtype, (x.type, mlir_dtype)
     return x
   elif isinstance(x, mgpu.FragmentedArray):
     assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype)
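
The rewritten _bcast_wg first materializes scalars as ir.Values, then splats any scalar operand to the output shape; mismatched non-scalar shapes now raise instead of tripping an assert. The jnp-level behavior it implements, as a CPU-runnable sketch:

import jax.numpy as jnp

x = jnp.float32(2.0)               # scalar operand: splat to (4, 4)
y = jnp.ones((4, 4), jnp.float32)
assert (x + y).shape == (4, 4)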


@@ -880,7 +880,7 @@ def _jaxpr_call_lowering_rule(
       program_ids[axis] = lowering._program_id(axis, ctx.module_ctx.squashed_dims)
   new_module_ctx = dataclasses.replace(ctx.module_ctx, program_ids=program_ids)
   return lowering.lower_jaxpr_to_mosaic_gpu(
-      new_module_ctx, ctx.launch_ctx, jaxpr, args
+      new_module_ctx, ctx.launch_ctx, jaxpr, args, thread_semantics=ctx.thread_semantics,
   )
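
Every recursive entry into lower_jaxpr_to_mosaic_gpu (loop bodies, cond branches, jaxpr calls) now forwards thread_semantics, so nested jaxprs keep lowering under the parent's semantics. The recurring pattern, restated as a sketch with the names used above:

outs = lowering.lower_jaxpr_to_mosaic_gpu(
    new_module_ctx, ctx.launch_ctx, jaxpr, args,
    thread_semantics=ctx.thread_semantics,  # never left to the default
)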


@@ -114,8 +114,6 @@ class PallasCallTest(PallasTest):
       thread_semantics=[*plgpu.ThreadSemantics],
   )
   def test_binary_op(self, op, dtype, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
-
     @functools.partial(
         pl.pallas_call,
@@ -146,8 +144,6 @@ class PallasCallTest(PallasTest):
       thread_semantics=[*plgpu.ThreadSemantics],
   )
   def test_comparison_op(self, op, dtype, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
-
     @functools.partial(
         pl.pallas_call,
@@ -319,8 +315,6 @@ class PallasCallTest(PallasTest):
       thread_semantics=[*plgpu.ThreadSemantics],
   )
   def test_copy_smem_to_gmem(self, indexer, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
-
     @functools.partial(
         pl.pallas_call,
@@ -788,8 +782,6 @@ class PallasCallTest(PallasTest):
   @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics])
   def test_run_scoped(self, thread_semantics):
-    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
-      self.skipTest("Needs scan_p WG lowering")
-
     def kernel(x_ref, o_ref):
       def body(tmp_ref):
@@ -906,12 +898,18 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128)
     np.testing.assert_array_equal(kernel(x), x)
 
-  @parameterized.parameters(False, True)
-  def test_fori_loop_array(self, force_while):
+  @parameterized.product(
+      force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
+  )
+  def test_fori_loop_array(self, force_while, thread_semantics):
+    if thread_semantics == plgpu.ThreadSemantics.Warpgroup:
+      # TODO(apaszke,bchetioui,slebedev): Support while + array carries.
+      self.skipTest("WG semantics unsupported")
+
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
     )
     def kernel(x_ref, o_ref):
       # Equivalent to x_ref[...] + 2 + 3.
@@ -920,12 +918,17 @@ class PallasCallTest(PallasTest):
     x = jnp.arange(256, dtype=jnp.int32)
     np.testing.assert_array_equal(kernel(x), x + 2 + 3)
 
-  @parameterized.parameters(False, True)
-  def test_fori_loop_scalar(self, force_while):
+  @parameterized.product(
+      force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics]
+  )
+  def test_fori_loop_scalar(self, force_while, thread_semantics):
+    if force_while and thread_semantics == plgpu.ThreadSemantics.Warpgroup:
+      self.skipTest("WG semantics does not support force_while.")
+
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
     )
     def kernel(o_ref):
       # Equivalent to 2 + 3.
@@ -1035,25 +1038,26 @@ class PallasCallTest(PallasTest):
     with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"):
       kernel()
 
-  def test_cond(self):
+  @parameterized.parameters([*plgpu.ThreadSemantics])
+  def test_cond(self, thread_semantics):
     @functools.partial(
         pl.pallas_call,
         out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+        compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics),
     )
     def kernel(x_ref, o_ref):
-      acc = _sum_same_dtype(x_ref[...])
       jax.lax.cond(
-          acc % 2 == 0,
-          lambda: pl.debug_print("acc * 2: {}", acc * 2),
-          lambda: pl.debug_print("acc: {}", acc),
+          x_ref[0] % 2 == 0,
+          lambda: pl.debug_print("acc % 2"),
+          lambda: pl.debug_print("acc"),
       )
-      o_ref[...] = jnp.broadcast_to(acc, o_ref.shape)
+      o_ref[...] = jnp.broadcast_to(jnp.asarray(0, dtype=o_ref.dtype), o_ref.shape)
 
-    x = jnp.arange(256, dtype=jnp.int32)
+    x = jnp.full((256,), 1234, dtype=jnp.int32)
     with self.capture_stdout() as output:
       jax.block_until_ready(kernel(x))
-    self.assertIn("acc * 2:", output())
+    self.assertIn("acc % 2", output())
 
   def test_cond_returning_array(self):
     @functools.partial(