mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mosaic:gpu] Document behaviour of warp_idx
and warpgroup_idx
.
Extracted common warp broadcasting code. PiperOrigin-RevId: 640475128
This commit is contained in:
parent
c2a3c0bb80
commit
485fad5679
@ -190,25 +190,28 @@ def thread_idx():
|
||||
return tidx
|
||||
|
||||
|
||||
def _warp_bcast(val, lane_idx=0):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
mask = c(0xFFFFFFFF, i32)
|
||||
return nvvm.shfl_sync(
|
||||
val.type, mask, val, c(lane_idx, i32), c(0x1F, i32), nvvm.ShflKind.idx
|
||||
)
|
||||
|
||||
|
||||
def warp_idx(sync=True):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
warp_idx = arith.shrui(thread_idx(), c(5, i32))
|
||||
if not sync:
|
||||
return warp_idx
|
||||
mask = c(0xFFFFFFFF, i32)
|
||||
return nvvm.shfl_sync(
|
||||
warp_idx.type, mask, warp_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx
|
||||
)
|
||||
# Performing a warp broadcast improves performance as compiler understands
|
||||
# that the value is uniform across the warp.
|
||||
return _warp_bcast(warp_idx) if sync else warp_idx
|
||||
|
||||
|
||||
def warpgroup_idx(sync=True):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
wg_idx = arith.shrui(thread_idx(), c(7, i32))
|
||||
if not sync:
|
||||
return wg_idx
|
||||
mask = c(0xFFFFFFFF, i32)
|
||||
return nvvm.shfl_sync(
|
||||
wg_idx.type, mask, wg_idx, c(0, i32), c(0x1F, i32), nvvm.ShflKind.idx
|
||||
)
|
||||
# Performing a warp broadcast improves performance as compiler understands
|
||||
# that the value is uniform across the warp.
|
||||
return _warp_bcast(wg_idx) if sync else wg_idx
|
||||
|
||||
|
||||
# True withon `once()` contexts.
|
||||
|
Loading…
x
Reference in New Issue
Block a user