[Pallas:MGPU] Don't recreate single_thread_predicate at every rule

While the predicate helps us avoid branching, it can be created once per
block instead of once per lowering rule. Its creation uses `*.sync`
instructions, which LLVM does not remove through dead code elimination (DCE),
so every redundant copy ends up polluting the final code.

PiperOrigin-RevId: 731253109
This commit is contained in:
Adam Paszke 2025-02-26 04:01:32 -08:00 committed by jax authors
parent 7a34f1cedc
commit 3251b55ef2
2 changed files with 9 additions and 4 deletions

View File

@ -236,6 +236,7 @@ class ModuleContext:
grid_names: Sequence[Hashable] | None
program_ids: Sequence[ir.Value] | None
approx_math: bool
single_wg_lane_predicate: ir.Value
runtime_smem: ir.Value # ir.MemRefType
smem_used_bytes: int
runtime_barriers: MutableMapping[
@ -308,7 +309,6 @@ class ModuleContext:
class LoweringRuleContext:
module_ctx: ModuleContext
launch_ctx: mgpu.LaunchContext
predicate: ir.Value
prim: jax_core.Primitive
avals_in: Sequence[jax_core.ShapedArray]
avals_out: Sequence[jax_core.ShapedArray]
@ -595,6 +595,7 @@ def lower_jaxpr_to_module(
grid_names,
[_program_id(axis, squashed_dims) for axis in range(len(grid))],
approx_math,
mgpu.single_thread_predicate(per_block=False),
runtime_smem,
smem_used_bytes=0,
runtime_barriers=grouped_barriers,
@ -753,7 +754,6 @@ def lower_jaxpr_to_mosaic_gpu(
rule_ctx = LoweringRuleContext(
module_ctx,
launch_ctx,
predicate=mgpu.single_thread_predicate(per_block=False),
avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
prim=eqn.primitive,

View File

@ -87,7 +87,7 @@ def _copy_smem_to_gmem_lowering(
dst_transforms_treedef,
has_user_predicate,
):
predicate = ctx.predicate
predicate = ctx.module_ctx.single_wg_lane_predicate
if has_user_predicate:
flat_args, user_predicate = flat_args[:-1], flat_args[-1]
predicate = arith_dialect.andi(
@ -273,7 +273,12 @@ def _copy_gmem_to_smem_lowering(
bytes //= WARPGROUP_SIZE
barrier.arrive_expect_tx(bytes)
ctx.launch_ctx.async_copy(
src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params
src_ref=src,
dst_ref=dst,
barrier=barrier,
arrive=False,
predicate=ctx.module_ctx.single_wg_lane_predicate,
**copy_params,
)
return ()