mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[MosaicGPU] Move parity computations to a separate function to allow the user to use wait_parity without duplicate code.
PiperOrigin-RevId: 652665738
This commit is contained in:
parent
0690988626
commit
28ffa25496
@ -560,16 +560,20 @@ class Barrier:
|
||||
scf.yield_([])
|
||||
|
||||
def wait(self, expect_wait=False):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
parities = memref.load(self.barrier_array.phases, [])
|
||||
parity, new_parities = self.update_parities(parities)
|
||||
memref.store(new_parities, self.barrier_array.phases, [])
|
||||
self.wait_parity(parity, expect_wait=expect_wait)
|
||||
|
||||
def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]:
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
offset_i32 = arith.index_castui(i32, self.offset)
|
||||
bitmask = arith.shli(c(1, i32), offset_i32)
|
||||
parity = arith.cmpi(
|
||||
arith.CmpIPredicate.ne, arith.andi(parities, bitmask), c(0, i32)
|
||||
)
|
||||
new_parities = arith.xori(parities, bitmask)
|
||||
memref.store(new_parities, self.barrier_array.phases, [])
|
||||
self.wait_parity(parity, expect_wait=expect_wait)
|
||||
return parity, arith.xori(parities, bitmask)
|
||||
|
||||
|
||||
def arrive(self):
|
||||
token_ty = ir.Type.parse("!nvgpu.mbarrier.token")
|
||||
|
Loading…
x
Reference in New Issue
Block a user