From 3d87a01beab8da14051386d7ec8a73cc6d55b884 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 25 Feb 2025 08:39:04 -0800 Subject: [PATCH] [Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes The Pallas-level pipelining generates a number of ops we haven't had to deal with before like conditionals, scans, etc. PiperOrigin-RevId: 730899808 --- jax/_src/pallas/mosaic_gpu/lowering.py | 143 +++++++++++++++-------- jax/_src/pallas/mosaic_gpu/primitives.py | 2 +- tests/pallas/mosaic_gpu_test.py | 44 +++---- 3 files changed, 120 insertions(+), 69 deletions(-) diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index f13776d1f..21551b435 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -692,7 +692,8 @@ def lower_jaxpr_to_mosaic_gpu( def read_env(atom: jax_core.Atom): return atom.val if isinstance(atom, jax_core.Literal) else env[atom] - def write_env(var: jax_core.Var, val): + def write_env(var: jax_core.Var, val, require_value: bool = True): + env[var] = val # TODO(apaszke): Handle other avals (refs, etc.). if isinstance(aval := var.aval, jax_core.ShapedArray): # TODO(apaszke): Clarify the type invariants for lane semantics? @@ -701,7 +702,12 @@ def lower_jaxpr_to_mosaic_gpu( # Those with empty shapes should be represented by their scalar type. mlir_dtype = mgpu_utils.dtype_to_ir_type(aval.dtype) if not isinstance(val, ir.Value): - raise AssertionError(f"Shaped arrays must be represented by ir.Values, got: {val}") + if require_value: + raise AssertionError(f"Shaped arrays must be represented by ir.Values, got: {val}") + else: + if var.aval.shape: + raise AssertionError("Only scalars can be represented by non-ir.Values") + return # Skip following checks. if aval.shape: if not ir.VectorType.isinstance(val.type): raise AssertionError(f"Non-scalar arrays must be represented by vectors, got: {val.type}") @@ -715,10 +721,9 @@ def lower_jaxpr_to_mosaic_gpu( raise AssertionError(f"Scalars must be represented by non-vector types, got: {val.type}") if val.type != mlir_dtype: raise AssertionError(f"Scalar type must match ShapedArray dtype, got: {val.type} != {mlir_dtype}") - env[var] = val map(write_env, jaxpr.constvars, consts) - map(write_env, jaxpr.invars, args) + map(lambda v, a: write_env(v, a, require_value=False), jaxpr.invars, args) # TODO(justinfu): Handle transform scopes. 
last_local_name_stack: list[str] = [] named_regions = [] @@ -958,9 +963,9 @@ def _get_lowering_rule_wg(ctx: LoweringRuleContext, x_smem, *leaves, tree): if shape: zero_index = arith_dialect.constant(ir.IndexType.get(), 0) indices = [zero_index for _ in range(len(shape))] + return vector_dialect.load(ty, x_smem, indices) else: - indices = [] - return vector_dialect.load(ty, x_smem, indices) + return memref_dialect.load(x_smem, []) @register_lowering_rule(sp.swap_p, mgpu.ThreadSemantics.Lane) @@ -1019,10 +1024,11 @@ def _swap_lowering_rule_wg( if shape: zero_index = arith_dialect.constant(ir.IndexType.get(), 0) indices = [zero_index for _ in range(len(shape))] + old_value = vector_dialect.load(ty, x_smem, indices) + vector_dialect.store(value, x_smem, indices) else: - indices = [] - old_value = vector_dialect.load(ty, x_smem, indices) - vector_dialect.store(value, x_smem, indices) + old_value = memref_dialect.load(x_smem, []) + memref_dialect.store(value, x_smem, []) return old_value @@ -1051,6 +1057,7 @@ def _slice_lowering_rule( @register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.select_n_p, mgpu.ThreadSemantics.Warpgroup) def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): if len(cases) != 2: raise NotImplementedError( @@ -1059,11 +1066,22 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): ) pred_aval, *cases_avals = ctx.avals_in [out_aval] = ctx.avals_out - pred = _ensure_fa(pred, pred_aval.dtype) - cases = _bcast(*cases, *cases_avals, out_aval) - # ``select`` expects the first case to be the true branch, but ``select_n`` - # orders the cases in reverse. - return pred.select(*reversed(cases)) + if ctx.thread_semantics == mgpu.ThreadSemantics.Lane: + pred = _ensure_fa(pred, pred_aval.dtype) + cases = _bcast(*cases, *cases_avals, out_aval) + # ``select`` expects the first case to be the true branch, but ``select_n`` + # orders the cases in reverse. + return pred.select(*reversed(cases)) + else: + pred = _ensure_ir_value(pred, pred_aval.dtype) + cases = [_ensure_ir_value(c, c_aval.dtype) for c, c_aval in zip(cases, cases_avals)] + # TODO(bchetioui): support implicit broadcast. + if any(a.shape != out_aval.shape for a in ctx.avals_in): + raise NotImplementedError( + "Implicit broadcast not implemented with warpgroup semantics") + # ``select`` expects the first case to be the true branch, but ``select_n`` + # orders the cases in reverse. 
+ return arith_dialect.select(pred, *reversed(cases)) @register_lowering_rule(lax.broadcast_in_dim_p, mgpu.ThreadSemantics.Lane) @@ -1148,11 +1166,13 @@ def _convert_element_type_lowering_rule_wg( elif from_integer and to_integer: if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width: convert = arith_dialect.trunci - else: + elif ir.IntegerType(cur_dtype).width < ir.IntegerType(new_dtype).width: if mgpu_utils.is_signed(x_aval.dtype): convert = arith_dialect.extsi else: convert = arith_dialect.extui + else: + convert = lambda _, x: x # signed <-> unsigned conversions elif from_integer and to_float: if mgpu_utils.is_signed(x_aval.dtype): convert = arith_dialect.sitofp @@ -1229,20 +1249,17 @@ for op, si_impl, ui_impl, f_impl in [ arith_dialect.divf, ), (lax.rem_p, arith_dialect.remsi, arith_dialect.remui, arith_dialect.remf), - (lax.and_p, arith_dialect.andi, arith_dialect.andi, None), - (lax.or_p, arith_dialect.ori, arith_dialect.ori, None), - (lax.xor_p, arith_dialect.xori, arith_dialect.xori, None), ( lax.max_p, arith_dialect.maxsi, arith_dialect.maxui, - arith_dialect.maxnumf, + arith_dialect.maximumf, ), ( lax.min_p, arith_dialect.minsi, arith_dialect.minui, - arith_dialect.minnumf, + arith_dialect.minimumf, ), ]: mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( @@ -1252,6 +1269,23 @@ for op, si_impl, ui_impl, f_impl in [ f_impl=f_impl, ) + +def _binary_boolean_op_lowering_rule_wg( + ctx: LoweringRuleContext, x, y, *, impl +): + x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out) + return impl(x, y) + +for op, impl in [ + (lax.and_p, arith_dialect.andi), + (lax.or_p, arith_dialect.ori), + (lax.xor_p, arith_dialect.xori), +]: + mosaic_lowering_rules[mgpu.ThreadSemantics.Warpgroup][op] = partial( + _binary_boolean_op_lowering_rule_wg, + impl=impl, + ) + CmpIPred = arith_dialect.CmpIPredicate CmpFPred = arith_dialect.CmpFPredicate @@ -1262,7 +1296,7 @@ def _comparison_lowering_rule_wg( x, y = _bcast_wg(x, y, *ctx.avals_in, *ctx.avals_out) if jnp.issubdtype(x_aval, jnp.signedinteger): return arith_dialect.cmpi(si_pred, x, y) - elif jnp.issubdtype(x_aval, jnp.integer): + elif jnp.issubdtype(x_aval, jnp.integer) or jnp.issubdtype(x_aval, jnp.bool): return arith_dialect.cmpi(ui_pred, x, y) elif jnp.issubdtype(x_aval, jnp.floating): return arith_dialect.cmpf(f_pred, x, y) @@ -1452,6 +1486,19 @@ def _debug_print_lowering_rule( return () +@register_lowering_rule(primitives.debug_print_p, mgpu.ThreadSemantics.Warpgroup) +def _debug_print_lowering_rule( + ctx: LoweringRuleContext, + *args, + fmt, + has_placeholders: bool, +): + del ctx, has_placeholders # Unused. + if args: + raise NotImplementedError("debug_print only supports string messages in warpgroup semantics") + mgpu.debug_print(fmt) + return () + @register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Lane) @register_lowering_rule(primitives.run_scoped_p, mgpu.ThreadSemantics.Warpgroup) @@ -1593,13 +1640,18 @@ def _lower_jaxpr_to_for_loop( out_avals = ctx.avals_out[-len(arg_avals):] is_acc = [isinstance(v, mgpu.WGMMAAccumulator) for v in args] - def as_fas(vals, avals): + def as_values(vals, avals): if is_acc != [isinstance(v, mgpu.WGMMAAccumulator) for v in vals]: raise ValueError("Unexpected loop carry w.r.t. 
accumulators.") - return [v if a else _ensure_fa(v, av) for a, v, av in zip(is_acc, vals, avals)] + _ensure = ( + _ensure_fa + if ctx.thread_semantics == mgpu.ThreadSemantics.Lane + else _ensure_ir_value + ) + return [v if a else _ensure(v, av) for a, v, av in zip(is_acc, vals, avals)] - @mgpu.fori(length, as_fas(args, arg_avals)) + @mgpu.fori(length, as_values(args, arg_avals)) def loop(loop_index, body_args): if has_loop_index: loop_index = arith_dialect.addi(loop_index, start) @@ -1607,14 +1659,15 @@ def _lower_jaxpr_to_for_loop( else: jaxpr_args = [*consts, *body_args] outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args + ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args, thread_semantics=ctx.thread_semantics, ) - return as_fas(outs, out_avals) + return as_values(outs, out_avals) return loop.results @register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.scan_p, mgpu.ThreadSemantics.Warpgroup) def _scan_lowering_rule( ctx: LoweringRuleContext, *args, @@ -1781,6 +1834,7 @@ def _while_lowering_rule( @register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Lane) +@register_lowering_rule(lax.cond_p, mgpu.ThreadSemantics.Warpgroup) def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): index_aval, *_arg_avals = ctx.avals_in @@ -1800,7 +1854,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): # extract the return types with ir.InsertionPoint(ir.Module.create().body): outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args + ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args, thread_semantics=ctx.thread_semantics, ) yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))] del outs @@ -1822,7 +1876,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): for branch, region in zip(branches, regions): with ir.InsertionPoint(region.blocks.append()): outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts + ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts, thread_semantics=ctx.thread_semantics, ) yielded_leaves, yielded_treedef = jax.tree.flatten(_yielded_values(outs, ctx.avals_out)) @@ -1917,48 +1971,41 @@ def _bcast_wg( if not out_aval.shape: return _ensure_ir_value(x, x_aval.dtype), _ensure_ir_value(y, y_aval.dtype) x_dtype = x_aval.dtype - if not isinstance(x, ir.Value) or not ir.VectorType.isinstance(x.type): + if not isinstance(x, ir.Value): if x_aval.weak_type: x_dtype = y_aval.dtype - x = _ensure_vector(x, x_dtype) + x = _ensure_ir_value(x, x_dtype) y_dtype = y_aval.dtype - if not isinstance(y, ir.Value) or not ir.VectorType.isinstance(y.type): + if not isinstance(y, ir.Value): if y_aval.weak_type: y_dtype = x_aval.dtype - y = _ensure_vector(y, y_dtype) - if x_aval.shape != out_aval.shape: - assert not x_aval.shape # TODO(slebedev): Support non-scalar inputs. + y = _ensure_ir_value(y, y_dtype) + if not ir.VectorType.isinstance(x.type): + assert not x_aval.shape x = vector_dialect.splat( ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(x_dtype)), x, ) - if y_aval.shape != out_aval.shape: - assert not y_aval.shape # TODO(slebedev): Support non-scalar inputs. 
+ elif x_aval.shape != out_aval.shape: + raise NotImplementedError("Unsupported broadcast") + if not ir.VectorType.isinstance(y.type): + assert not y_aval.shape y = vector_dialect.splat( ir.VectorType.get(out_aval.shape, mgpu_utils.dtype_to_ir_type(y_dtype)), y, ) + elif y_aval.shape != out_aval.shape: + raise NotImplementedError("Unsupported broadcast") return x, y -def _ensure_vector(x: object, dtype: jnp.dtype) -> ir.Value: - if isinstance(x, ir.Value) and ir.VectorType.isinstance(x.type): - assert ir.VectorType(x.type).element_type == mgpu_utils.dtype_to_ir_type(dtype) - return x - - if isinstance(x, (np.number, np.ndarray, int, float)): - return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) - - raise NotImplementedError(f"Unsupported conversion to vector for: {x!r}") - - def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: if isinstance(x, ir.Value): mlir_dtype = mgpu_utils.dtype_to_ir_type(dtype) if ir.VectorType.isinstance(x.type): assert ir.VectorType(x.type).element_type == mlir_dtype else: - assert x.type == mlir_dtype + assert x.type == mlir_dtype, (x.type, mlir_dtype) return x elif isinstance(x, mgpu.FragmentedArray): assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index e4e2ed8a7..1bbcdf4e2 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -880,7 +880,7 @@ def _jaxpr_call_lowering_rule( program_ids[axis] = lowering._program_id(axis, ctx.module_ctx.squashed_dims) new_module_ctx = dataclasses.replace(ctx.module_ctx, program_ids=program_ids) return lowering.lower_jaxpr_to_mosaic_gpu( - new_module_ctx, ctx.launch_ctx, jaxpr, args + new_module_ctx, ctx.launch_ctx, jaxpr, args, thread_semantics=ctx.thread_semantics, ) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index f77853bc9..51d769d44 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -114,8 +114,6 @@ class PallasCallTest(PallasTest): thread_semantics=[*plgpu.ThreadSemantics], ) def test_binary_op(self, op, dtype, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("Needs scan_p WG lowering") @functools.partial( pl.pallas_call, @@ -146,8 +144,6 @@ class PallasCallTest(PallasTest): thread_semantics=[*plgpu.ThreadSemantics], ) def test_comparison_op(self, op, dtype, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("Needs scan_p WG lowering") @functools.partial( pl.pallas_call, @@ -319,8 +315,6 @@ class PallasCallTest(PallasTest): thread_semantics=[*plgpu.ThreadSemantics], ) def test_copy_smem_to_gmem(self, indexer, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("Needs scan_p WG lowering") @functools.partial( pl.pallas_call, @@ -788,8 +782,6 @@ class PallasCallTest(PallasTest): @parameterized.product(thread_semantics=[*plgpu.ThreadSemantics]) def test_run_scoped(self, thread_semantics): - if thread_semantics == plgpu.ThreadSemantics.Warpgroup: - self.skipTest("Needs scan_p WG lowering") def kernel(x_ref, o_ref): def body(tmp_ref): @@ -906,12 +898,18 @@ class PallasCallTest(PallasTest): x = jnp.arange(128 * 128).astype(jnp.float16).reshape(128, 128) np.testing.assert_array_equal(kernel(x), x) - @parameterized.parameters(False, True) - def test_fori_loop_array(self, force_while): + @parameterized.product( + force_while=[False, True], 
thread_semantics=[*plgpu.ThreadSemantics] + ) + def test_fori_loop_array(self, force_while, thread_semantics): + if thread_semantics == plgpu.ThreadSemantics.Warpgroup: + # TODO(apaszke,bchetioui,slebedev): Support while + array carries. + self.skipTest("WG semantics unsupported") @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), ) def kernel(x_ref, o_ref): # Equivalent to x_ref[...] + 2 + 3. @@ -920,12 +918,17 @@ class PallasCallTest(PallasTest): x = jnp.arange(256, dtype=jnp.int32) np.testing.assert_array_equal(kernel(x), x + 2 + 3) - @parameterized.parameters(False, True) - def test_fori_loop_scalar(self, force_while): + @parameterized.product( + force_while=[False, True], thread_semantics=[*plgpu.ThreadSemantics] + ) + def test_fori_loop_scalar(self, force_while, thread_semantics): + if force_while and thread_semantics == plgpu.ThreadSemantics.Warpgroup: + self.skipTest("WG semantics does not support force_while.") @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), ) def kernel(o_ref): # Equivalent to 2 + 3. @@ -1035,25 +1038,26 @@ class PallasCallTest(PallasTest): with self.assertRaisesRegex(ValueError, "has layout .*, when it should be"): kernel() - def test_cond(self): + @parameterized.parameters([*plgpu.ThreadSemantics]) + def test_cond(self, thread_semantics): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + compiler_params=plgpu.GPUCompilerParams(thread_semantics=thread_semantics), ) def kernel(x_ref, o_ref): - acc = _sum_same_dtype(x_ref[...]) jax.lax.cond( - acc % 2 == 0, - lambda: pl.debug_print("acc * 2: {}", acc * 2), - lambda: pl.debug_print("acc: {}", acc), + x_ref[0] % 2 == 0, + lambda: pl.debug_print("acc % 2"), + lambda: pl.debug_print("acc"), ) - o_ref[...] = jnp.broadcast_to(acc, o_ref.shape) + o_ref[...] = jnp.broadcast_to(jnp.asarray(0, dtype=o_ref.dtype), o_ref.shape) - x = jnp.arange(256, dtype=jnp.int32) + x = jnp.full((256,), 1234, dtype=jnp.int32) with self.capture_stdout() as output: jax.block_until_ready(kernel(x)) - self.assertIn("acc * 2:", output()) + self.assertIn("acc % 2", output()) def test_cond_returning_array(self): @functools.partial(
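
For context (not part of the patch): below is a minimal sketch of the kind of kernel the new warpgroup lowerings enable, modeled on the updated test_cond test in the diff above. It assumes the usual test imports (pl for jax.experimental.pallas, plgpu for jax.experimental.pallas.mosaic_gpu); GPUCompilerParams, ThreadSemantics, and pl.debug_print are taken directly from the test changes in this diff, and the kernel body is only illustrative.

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


# A kernel compiled with warpgroup thread semantics that relies on the
# lax.cond lowering registered in this change. Note that debug_print under
# warpgroup semantics only supports plain string messages (no placeholders).
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
    compiler_params=plgpu.GPUCompilerParams(
        thread_semantics=plgpu.ThreadSemantics.Warpgroup
    ),
)
def kernel(x_ref, o_ref):
  jax.lax.cond(
      x_ref[0] % 2 == 0,
      lambda: pl.debug_print("first element is even"),
      lambda: pl.debug_print("first element is odd"),
  )
  o_ref[...] = x_ref[...] + 1


x = jnp.arange(256, dtype=jnp.int32)
print(kernel(x))  # Requires a GPU supported by the Mosaic GPU backend.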