mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Pallas:MGPU] Don't recreate single_thread_predicate at every rule
While the predicate helps us avoid branching, it can be created once per block. Its creation uses `*.sync` instructions, which are not DCEd by LLVM and end up polluting the final code. PiperOrigin-RevId: 731253109
This commit is contained in:
parent
7a34f1cedc
commit
3251b55ef2
@ -236,6 +236,7 @@ class ModuleContext:
|
||||
grid_names: Sequence[Hashable] | None
|
||||
program_ids: Sequence[ir.Value] | None
|
||||
approx_math: bool
|
||||
single_wg_lane_predicate: ir.Value
|
||||
runtime_smem: ir.Value # ir.MemRefType
|
||||
smem_used_bytes: int
|
||||
runtime_barriers: MutableMapping[
|
||||
@ -308,7 +309,6 @@ class ModuleContext:
|
||||
class LoweringRuleContext:
|
||||
module_ctx: ModuleContext
|
||||
launch_ctx: mgpu.LaunchContext
|
||||
predicate: ir.Value
|
||||
prim: jax_core.Primitive
|
||||
avals_in: Sequence[jax_core.ShapedArray]
|
||||
avals_out: Sequence[jax_core.ShapedArray]
|
||||
@ -595,6 +595,7 @@ def lower_jaxpr_to_module(
|
||||
grid_names,
|
||||
[_program_id(axis, squashed_dims) for axis in range(len(grid))],
|
||||
approx_math,
|
||||
mgpu.single_thread_predicate(per_block=False),
|
||||
runtime_smem,
|
||||
smem_used_bytes=0,
|
||||
runtime_barriers=grouped_barriers,
|
||||
@ -753,7 +754,6 @@ def lower_jaxpr_to_mosaic_gpu(
|
||||
rule_ctx = LoweringRuleContext(
|
||||
module_ctx,
|
||||
launch_ctx,
|
||||
predicate=mgpu.single_thread_predicate(per_block=False),
|
||||
avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars],
|
||||
avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars],
|
||||
prim=eqn.primitive,
|
||||
|
@ -87,7 +87,7 @@ def _copy_smem_to_gmem_lowering(
|
||||
dst_transforms_treedef,
|
||||
has_user_predicate,
|
||||
):
|
||||
predicate = ctx.predicate
|
||||
predicate = ctx.module_ctx.single_wg_lane_predicate
|
||||
if has_user_predicate:
|
||||
flat_args, user_predicate = flat_args[:-1], flat_args[-1]
|
||||
predicate = arith_dialect.andi(
|
||||
@ -273,7 +273,12 @@ def _copy_gmem_to_smem_lowering(
|
||||
bytes //= WARPGROUP_SIZE
|
||||
barrier.arrive_expect_tx(bytes)
|
||||
ctx.launch_ctx.async_copy(
|
||||
src_ref=src, dst_ref=dst, barrier=barrier, arrive=False, **copy_params
|
||||
src_ref=src,
|
||||
dst_ref=dst,
|
||||
barrier=barrier,
|
||||
arrive=False,
|
||||
predicate=ctx.module_ctx.single_wg_lane_predicate,
|
||||
**copy_params,
|
||||
)
|
||||
return ()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user