[Mosaic GPU] Automatically format the Mosaic GPU dialect test python code

This allows me to keep using the formatter going forward and not have to bother manually formatting code.

PiperOrigin-RevId: 705024602
This commit is contained in:
Dimitar (Mitko) Asenov 2024-12-11 02:03:31 -08:00 committed by jax authors
parent 66f45d039f
commit 3d9c720d42

View File

@ -74,7 +74,8 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:
def workgroup_ptr_ty() -> ir.Type:
workgroup_nvptx_address_space = gpu_address_space_to_nvptx(
gpu.AddressSpace.Workgroup)
gpu.AddressSpace.Workgroup
)
return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
@ -95,7 +96,9 @@ class DialectTest(parameterized.TestCase):
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.F32Type.get()),
llvm.UndefOp(workgroup_ptr_ty()), arrival_count=1)
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1,
)
with self.assertRaisesRegex(
ir.MLIRError, "must be memref of barrier values"
):
@ -106,7 +109,8 @@ class DialectTest(parameterized.TestCase):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=0)
arrival_count=0,
)
with self.assertRaisesRegex(ir.MLIRError, "value is positive"):
self.module.operation.verify()
@ -115,7 +119,8 @@ class DialectTest(parameterized.TestCase):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")),
arrival_count=1)
arrival_count=1,
)
with self.assertRaisesRegex(ir.MLIRError, "pointer in address space 3"):
self.module.operation.verify()
@ -124,10 +129,12 @@ class DialectTest(parameterized.TestCase):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1)
arrival_count=1,
)
self.assertTrue(self.module.operation.verify())
self.assertIsInstance(self.module.body.operations[1],
mgpu.InitializeBarrierOp)
self.assertIsInstance(
self.module.body.operations[1], mgpu.InitializeBarrierOp
)
def test_async_load_op_dest_must_be_contiguous(self):
with ir.InsertionPoint(self.module.body):
@ -575,7 +582,8 @@ class DialectLoweringTest(DialectTest):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1)
arrival_count=1,
)
lower_mgpu_dialect(self.module)
self.assertEmpty(
@ -591,7 +599,8 @@ class DialectLoweringTest(DialectTest):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1)
arrival_count=1,
)
scf.yield_([])
lower_mgpu_dialect(self.module)
@ -608,7 +617,8 @@ class DialectLoweringTest(DialectTest):
barriers_ref = mgpu.initialize_barrier(
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=arrival_count)
arrival_count=arrival_count,
)
# Add a user for barriers_ref to make sure that the lowering keeps types
# consistent.
memref.copy(barriers_ref, barriers_ref)