[Mosaic GPU] Clean up imports in gpu_dialect_test.py.

PiperOrigin-RevId: 707549269
Author: Benjamin Chetioui, 2024-12-18 07:49:13 -08:00 (committed by jax authors)
Parent: 3d54d03529
Commit: 6a03ea3e73

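In short, the per-member imports from jax.experimental.mosaic.gpu are replaced by a single package import, and members are reached as attributes of the package. A minimal sketch of the resulting style, assuming only the calls that appear in the diff below (the helper name lower_test_module is illustrative, not part of the commit):

# Old style (removed below):
#   from jax.experimental.mosaic.gpu import dialect as mgpu
#   from jax.experimental.mosaic.gpu import lower_mgpu_dialect
# New style: one package import; members are accessed as attributes.
import jax.experimental.mosaic.gpu as mgpu

def lower_test_module(module):  # illustrative helper, not part of the commit
  # Dialect ops and types move under mgpu.dialect,
  # e.g. mgpu.dialect.initialize_barrier(...).
  mgpu.infer_layout(module)        # was: infer_layout(module)
  mgpu.lower_mgpu_dialect(module)  # was: lower_mgpu_dialect(module)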

@@ -29,13 +29,9 @@ from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
-from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
-from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import infer_layout # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import
-from jax.experimental.mosaic.gpu import strided_fragmented_layout # pylint: disable=g-importing-member
+import jax.experimental.mosaic.gpu as mgpu
-_cext = mgpu._cext if mgpu is not None else None
+_cext = mgpu.dialect._cext if mgpu.dialect is not None else None
config.parse_flags_with_absl()
@@ -45,7 +41,7 @@ def _make_ir_context():
context = ir.Context()
context.append_dialect_registry(mlir_interpreter.upstream_dialects)
context.load_all_available_dialects()
-mgpu.register_dialect(context)
+mgpu.dialect.register_dialect(context)
return context
@@ -76,7 +72,7 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:
def workgroup_ptr_ty() -> ir.Type:
-workgroup_nvptx_address_space = gpu_address_space_to_nvptx(
+workgroup_nvptx_address_space = mgpu.gpu_address_space_to_nvptx(
gpu.AddressSpace.Workgroup
)
return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
@@ -85,7 +81,7 @@ def workgroup_ptr_ty() -> ir.Type:
class MosaicGpuTest(parameterized.TestCase):
def setUp(self):
-if mgpu is None:
+if mgpu.dialect is None:
raise self.skipTest("Test requires Mosaic GPU dialect")
super().setUp()
self.enter_context(_make_ir_context())
@@ -100,7 +96,7 @@ class DialectTest(MosaicGpuTest):
def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
with ir.InsertionPoint(self.module.body):
-mgpu.initialize_barrier(
+mgpu.dialect.initialize_barrier(
ir.MemRefType.get((1, 2), ir.F32Type.get()),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1,
@@ -112,7 +108,7 @@ class DialectTest(MosaicGpuTest):
def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
with ir.InsertionPoint(self.module.body):
-mgpu.initialize_barrier(
+mgpu.dialect.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=0,
@@ -122,7 +118,7 @@ class DialectTest(MosaicGpuTest):
def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):
with ir.InsertionPoint(self.module.body):
-mgpu.initialize_barrier(
+mgpu.dialect.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")),
arrival_count=1,
@@ -132,14 +128,14 @@ class DialectTest(MosaicGpuTest):
def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self):
with ir.InsertionPoint(self.module.body):
-mgpu.initialize_barrier(
+mgpu.dialect.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1,
)
self.assertTrue(self.module.operation.verify())
self.assertIsInstance(
-self.module.body.operations[1], mgpu.InitializeBarrierOp
+self.module.body.operations[1], mgpu.dialect.InitializeBarrierOp
)
def test_async_load_op_dest_must_be_contiguous(self):
@@ -156,7 +152,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_load",
)(
-lambda source, destination, barrier, *indices: mgpu.async_load(
+lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
source,
destination,
barrier,
@@ -183,7 +179,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_load",
)(
-lambda source, destination, barrier, *indices: mgpu.async_load(
+lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
source,
destination,
barrier,
@@ -210,7 +206,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_load",
)(
-lambda source, destination, barrier, *indices: mgpu.async_load(
+lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
source,
destination,
barrier,
@@ -238,7 +234,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_load",
)(
-lambda source, destination, barrier, *indices: mgpu.async_load(
+lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
source,
destination,
barrier,
@@ -264,7 +260,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_load",
)(
-lambda source, destination, barrier, *indices: mgpu.async_load(
+lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
source,
destination,
barrier,
@@ -290,7 +286,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_load",
)(
-lambda source, destination, barrier, *indices: mgpu.async_load(
+lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
source,
destination,
barrier,
@@ -316,7 +312,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_load",
)(
-lambda source, destination, barrier, *indices: mgpu.async_load(
+lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
source,
destination,
barrier,
@@ -325,10 +321,10 @@ class DialectTest(MosaicGpuTest):
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([
ir.Attribute.parse(
-f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
+f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
),
ir.Attribute.parse(
-f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
+f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
),
]),
)
@@ -353,7 +349,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_store",
)(
-lambda source, destination, *indices: mgpu.async_store(
+lambda source, destination, *indices: mgpu.dialect.async_store(
source,
destination,
indices,
@@ -377,7 +373,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_store",
)(
-lambda source, destination, *indices: mgpu.async_store(
+lambda source, destination, *indices: mgpu.dialect.async_store(
source,
destination,
indices,
@@ -401,7 +397,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_store",
)(
-lambda source, destination, *indices: mgpu.async_store(
+lambda source, destination, *indices: mgpu.dialect.async_store(
source,
destination,
indices,
@@ -426,7 +422,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_store",
)(
-lambda source, destination, *indices: mgpu.async_store(
+lambda source, destination, *indices: mgpu.dialect.async_store(
source,
destination,
indices,
@@ -449,7 +445,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_store",
)(
-lambda source, destination, *indices: mgpu.async_store(
+lambda source, destination, *indices: mgpu.dialect.async_store(
source,
destination,
indices,
@@ -472,7 +468,7 @@ class DialectTest(MosaicGpuTest):
ir.IntegerType.get_signless(32),
name="async_store",
)(
-lambda source, destination, *indices: mgpu.async_store(
+lambda source, destination, *indices: mgpu.dialect.async_store(
source,
destination,
indices,
@@ -496,7 +492,7 @@ class DialectTest(MosaicGpuTest):
ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
name="wgmma",
)(
-lambda accumulator, a, b: mgpu.wgmma(
+lambda accumulator, a, b: mgpu.dialect.wgmma(
accumulator,
a,
b,
@@ -518,7 +514,7 @@ class DialectTest(MosaicGpuTest):
ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
name="wgmma",
)(
-lambda accumulator, a, b: mgpu.wgmma(
+lambda accumulator, a, b: mgpu.dialect.wgmma(
accumulator,
a,
b,
@@ -540,7 +536,7 @@ class DialectTest(MosaicGpuTest):
ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
name="wgmma",
)(
-lambda accumulator, a, b: mgpu.wgmma(
+lambda accumulator, a, b: mgpu.dialect.wgmma(
accumulator,
a,
b,
@@ -563,7 +559,7 @@ class DialectTest(MosaicGpuTest):
ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
name="wgmma",
)(
-lambda accumulator, a, b: mgpu.wgmma(
+lambda accumulator, a, b: mgpu.dialect.wgmma(
accumulator,
a,
b,
@@ -585,12 +581,12 @@ class DialectLoweringTest(MosaicGpuTest):
def test_lowering_removes_mosaic_gpu_ops(self):
with ir.InsertionPoint(self.module.body):
-mgpu.initialize_barrier(
+mgpu.dialect.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1,
)
-lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, self.module.body.operations))
@@ -602,13 +598,13 @@ class DialectLoweringTest(MosaicGpuTest):
cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
if_op = scf.IfOp(cst_true)
with ir.InsertionPoint(if_op.then_block):
-mgpu.initialize_barrier(
+mgpu.dialect.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=1,
)
scf.yield_([])
-lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
@@ -620,7 +616,7 @@ class DialectLoweringTest(MosaicGpuTest):
arrival_count = 1337
with ir.InsertionPoint(self.module.body):
-barriers_ref = mgpu.initialize_barrier(
+barriers_ref = mgpu.dialect.initialize_barrier(
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
llvm.UndefOp(workgroup_ptr_ty()),
arrival_count=arrival_count,
@@ -630,7 +626,7 @@ class DialectLoweringTest(MosaicGpuTest):
memref.copy(barriers_ref, barriers_ref)
self.assertTrue(self.module.operation.verify())
-lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module)
self.assertTrue(self.module.operation.verify())
all_mbarrier_init_shared_ops = find_if(
@@ -658,7 +654,7 @@ class DialectLoweringTest(MosaicGpuTest):
with self.assertRaisesRegex(
ValueError, "missing a layout and can not be lowered"
):
-lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module)
def test_lowering_eliminates_layouts(self):
shape = (4, 128)
@@ -669,10 +665,10 @@ class DialectLoweringTest(MosaicGpuTest):
ty = ir.VectorType.get(shape, elt_ty)
load = vector.load(ty, ref, [zero_index, zero_index])
load.owner.attributes["out_layouts"] = ir.ArrayAttr.get(
-[strided_fragmented_layout()]
+[mgpu.strided_fragmented_layout()]
)
-lower_mgpu_dialect(self.module)
+mgpu.lower_mgpu_dialect(self.module)
all_ops_with_layouts = find_if(
self.module,
@@ -692,8 +688,8 @@ class DialectLoweringTest(MosaicGpuTest):
array = vector.load(ty, ref, [zero_index, zero_index])
vector.store(array, ref, [zero_index, zero_index])
-infer_layout(self.module)
-lower_mgpu_dialect(self.module)
+mgpu.infer_layout(self.module)
+mgpu.lower_mgpu_dialect(self.module)
all_loads = find_if(
self.module,