[mosaic:gpu] Document behaviour of warp_idx and warpgroup_idx.

Extracted common warp broadcasting code.

PiperOrigin-RevId: 640475128
This commit is contained in:
Chris Jones 2024-06-05 04:08:47 -07:00 committed by jax authors
parent c2a3c0bb80
commit 485fad5679

View File

@ -190,25 +190,28 @@ def thread_idx():
return tidx
def _warp_bcast(val, lane_idx=0):
  """Broadcasts `val` from lane `lane_idx` using an indexed `shfl.sync`.

  Args:
    val: The value to shuffle; the result has the same type (`val.type`).
    lane_idx: Python integer selecting the source lane. Defaults to lane 0.

  Returns:
    The result of the `nvvm.shfl_sync` op.
  """
  i32 = ir.IntegerType.get_signless(32)
  # All bits set in the shuffle's mask operand.
  # NOTE(review): presumably this means every lane of the warp takes part —
  # confirm against the NVVM `shfl.sync` documentation.
  full_mask = c(0xFFFFFFFF, i32)
  src_lane = c(lane_idx, i32)
  clamp = c(0x1F, i32)
  return nvvm.shfl_sync(
      val.type, full_mask, val, src_lane, clamp, nvvm.ShflKind.idx
  )
def warp_idx(sync: bool = True):
  """Returns the warp index, i.e. `thread_idx()` shifted right by 5 bits.

  Args:
    sync: If True (the default), the index is additionally passed through
      `_warp_bcast`, which emits a warp shuffle. If False, the raw shifted
      value is returned and no shuffle is emitted.

  Returns:
    An i32 value holding `thread_idx() >> 5` (the thread index divided by 32).
  """
  i32 = ir.IntegerType.get_signless(32)
  idx = arith.shrui(thread_idx(), c(5, i32))
  # Performing a warp broadcast improves performance as compiler understands
  # that the value is uniform across the warp.
  return _warp_bcast(idx) if sync else idx
def warpgroup_idx(sync: bool = True):
  """Returns the warpgroup index, i.e. `thread_idx()` shifted right by 7 bits.

  Args:
    sync: If True (the default), the index is additionally passed through
      `_warp_bcast`, which emits a warp shuffle. If False, the raw shifted
      value is returned and no shuffle is emitted.

  Returns:
    An i32 value holding `thread_idx() >> 7` (the thread index divided by
    128).
  """
  i32 = ir.IntegerType.get_signless(32)
  wg_idx = arith.shrui(thread_idx(), c(7, i32))
  # Performing a warp broadcast improves performance as compiler understands
  # that the value is uniform across the warp.
  return _warp_bcast(wg_idx) if sync else wg_idx
# True within `once()` contexts.