[MosaicGPU] Move parity computations to a separate function to allow the user to use wait_parity without duplicate code.

PiperOrigin-RevId: 652665738
This commit is contained in:
Christos Perivolaropoulos 2024-07-15 19:07:30 -07:00 committed by jax authors
parent 0690988626
commit 28ffa25496

View File

@ -560,16 +560,20 @@ class Barrier:
scf.yield_([])
def wait(self, expect_wait=False):
i32 = ir.IntegerType.get_signless(32)
parities = memref.load(self.barrier_array.phases, [])
parity, new_parities = self.update_parities(parities)
memref.store(new_parities, self.barrier_array.phases, [])
self.wait_parity(parity, expect_wait=expect_wait)
def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]:
i32 = ir.IntegerType.get_signless(32)
offset_i32 = arith.index_castui(i32, self.offset)
bitmask = arith.shli(c(1, i32), offset_i32)
parity = arith.cmpi(
arith.CmpIPredicate.ne, arith.andi(parities, bitmask), c(0, i32)
)
new_parities = arith.xori(parities, bitmask)
memref.store(new_parities, self.barrier_array.phases, [])
self.wait_parity(parity, expect_wait=expect_wait)
return parity, arith.xori(parities, bitmask)
def arrive(self):
token_ty = ir.Type.parse("!nvgpu.mbarrier.token")