mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic GPU][NFC] Add missing wgmma_{commit,wait}_group_sync_aligned
to test.
We need to wait for the results to be available before we can copy them to `smem`, and these instructions are not issued in the lowering of `mgpu_dialect.wgmma`. PiperOrigin-RevId: 725989759
This commit is contained in:
parent
8eea88626f
commit
dd4f396d90
@ -2495,6 +2495,9 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
swizzle=swizzle,
|
||||
)
|
||||
|
||||
nvvm.wgmma_commit_group_sync_aligned()
|
||||
nvvm.wgmma_wait_group_sync_aligned(0)
|
||||
|
||||
# Registers -> SMEM
|
||||
zero_index = arith.constant(ir.IndexType.get(), 0)
|
||||
vector.store(result, result_smem_ref, [zero_index, zero_index])
|
||||
|
Loading…
x
Reference in New Issue
Block a user