mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic GPU] Automatically format the Mosaic GPU dialect test python code
This allows me to keep using the formatter going forward and not have to bother manually formatting code. PiperOrigin-RevId: 705024602
This commit is contained in:
parent
66f45d039f
commit
3d9c720d42
@ -74,7 +74,8 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:
|
||||
|
||||
def workgroup_ptr_ty() -> ir.Type:
|
||||
workgroup_nvptx_address_space = gpu_address_space_to_nvptx(
|
||||
gpu.AddressSpace.Workgroup)
|
||||
gpu.AddressSpace.Workgroup
|
||||
)
|
||||
return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
|
||||
|
||||
|
||||
@ -95,7 +96,9 @@ class DialectTest(parameterized.TestCase):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.F32Type.get()),
|
||||
llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1)
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError, "must be memref of barrier values"
|
||||
):
|
||||
@ -106,7 +109,8 @@ class DialectTest(parameterized.TestCase):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=0)
|
||||
arrival_count=0,
|
||||
)
|
||||
with self.assertRaisesRegex(ir.MLIRError, "value is positive"):
|
||||
self.module.operation.verify()
|
||||
|
||||
@ -115,7 +119,8 @@ class DialectTest(parameterized.TestCase):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")),
|
||||
arrival_count=1)
|
||||
arrival_count=1,
|
||||
)
|
||||
with self.assertRaisesRegex(ir.MLIRError, "pointer in address space 3"):
|
||||
self.module.operation.verify()
|
||||
|
||||
@ -124,10 +129,12 @@ class DialectTest(parameterized.TestCase):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1)
|
||||
arrival_count=1,
|
||||
)
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
self.assertIsInstance(self.module.body.operations[1],
|
||||
mgpu.InitializeBarrierOp)
|
||||
self.assertIsInstance(
|
||||
self.module.body.operations[1], mgpu.InitializeBarrierOp
|
||||
)
|
||||
|
||||
def test_async_load_op_dest_must_be_contiguous(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
@ -575,7 +582,8 @@ class DialectLoweringTest(DialectTest):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1)
|
||||
arrival_count=1,
|
||||
)
|
||||
lower_mgpu_dialect(self.module)
|
||||
|
||||
self.assertEmpty(
|
||||
@ -591,7 +599,8 @@ class DialectLoweringTest(DialectTest):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1)
|
||||
arrival_count=1,
|
||||
)
|
||||
scf.yield_([])
|
||||
lower_mgpu_dialect(self.module)
|
||||
|
||||
@ -608,7 +617,8 @@ class DialectLoweringTest(DialectTest):
|
||||
barriers_ref = mgpu.initialize_barrier(
|
||||
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=arrival_count)
|
||||
arrival_count=arrival_count,
|
||||
)
|
||||
# Add a user for barriers_ref to make sure that the lowering keeps types
|
||||
# consistent.
|
||||
memref.copy(barriers_ref, barriers_ref)
|
||||
|
Loading…
x
Reference in New Issue
Block a user