mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Clean up imports in gpu_dialect_test.py
.
PiperOrigin-RevId: 707549269
This commit is contained in:
parent
3d54d03529
commit
6a03ea3e73
@ -29,13 +29,9 @@ from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import nvvm
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
|
||||
from jax.experimental.mosaic.gpu import gpu_address_space_to_nvptx # pylint: disable=g-importing-member,g-multiple-import
|
||||
from jax.experimental.mosaic.gpu import infer_layout # pylint: disable=g-importing-member,g-multiple-import
|
||||
from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import
|
||||
from jax.experimental.mosaic.gpu import strided_fragmented_layout # pylint: disable=g-importing-member
|
||||
import jax.experimental.mosaic.gpu as mgpu
|
||||
|
||||
_cext = mgpu._cext if mgpu is not None else None
|
||||
_cext = mgpu.dialect._cext if mgpu.dialect is not None else None
|
||||
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
@ -45,7 +41,7 @@ def _make_ir_context():
|
||||
context = ir.Context()
|
||||
context.append_dialect_registry(mlir_interpreter.upstream_dialects)
|
||||
context.load_all_available_dialects()
|
||||
mgpu.register_dialect(context)
|
||||
mgpu.dialect.register_dialect(context)
|
||||
return context
|
||||
|
||||
|
||||
@ -76,7 +72,7 @@ def is_mosaic_gpu_op(op: ir.OpView) -> bool:
|
||||
|
||||
|
||||
def workgroup_ptr_ty() -> ir.Type:
|
||||
workgroup_nvptx_address_space = gpu_address_space_to_nvptx(
|
||||
workgroup_nvptx_address_space = mgpu.gpu_address_space_to_nvptx(
|
||||
gpu.AddressSpace.Workgroup
|
||||
)
|
||||
return ir.Type.parse(f"!llvm.ptr<{workgroup_nvptx_address_space}>")
|
||||
@ -85,7 +81,7 @@ def workgroup_ptr_ty() -> ir.Type:
|
||||
class MosaicGpuTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if mgpu is None:
|
||||
if mgpu.dialect is None:
|
||||
raise self.skipTest("Test requires Mosaic GPU dialect")
|
||||
super().setUp()
|
||||
self.enter_context(_make_ir_context())
|
||||
@ -100,7 +96,7 @@ class DialectTest(MosaicGpuTest):
|
||||
|
||||
def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
mgpu.dialect.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.F32Type.get()),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1,
|
||||
@ -112,7 +108,7 @@ class DialectTest(MosaicGpuTest):
|
||||
|
||||
def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
mgpu.dialect.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=0,
|
||||
@ -122,7 +118,7 @@ class DialectTest(MosaicGpuTest):
|
||||
|
||||
def test_initialize_barrier_op_with_a_non_shared_base_pointer_fails(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
mgpu.dialect.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(ir.Type.parse(f"!llvm.ptr<{0}>")),
|
||||
arrival_count=1,
|
||||
@ -132,14 +128,14 @@ class DialectTest(MosaicGpuTest):
|
||||
|
||||
def test_initialize_barrier_op_with_a_positive_arrival_count_passes(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
mgpu.dialect.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1,
|
||||
)
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
self.assertIsInstance(
|
||||
self.module.body.operations[1], mgpu.InitializeBarrierOp
|
||||
self.module.body.operations[1], mgpu.dialect.InitializeBarrierOp
|
||||
)
|
||||
|
||||
def test_async_load_op_dest_must_be_contiguous(self):
|
||||
@ -156,7 +152,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
@ -183,7 +179,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
@ -210,7 +206,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
@ -238,7 +234,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
@ -264,7 +260,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
@ -290,7 +286,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
@ -316,7 +312,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
lambda source, destination, barrier, *indices: mgpu.dialect.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
@ -325,10 +321,10 @@ class DialectTest(MosaicGpuTest):
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([
|
||||
ir.Attribute.parse(
|
||||
f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
|
||||
f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
|
||||
),
|
||||
ir.Attribute.parse(
|
||||
f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
|
||||
f"#mosaic_gpu.dim<{mgpu.dialect.Dimension.x.name}>"
|
||||
),
|
||||
]),
|
||||
)
|
||||
@ -353,7 +349,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
lambda source, destination, *indices: mgpu.dialect.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
@ -377,7 +373,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
lambda source, destination, *indices: mgpu.dialect.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
@ -401,7 +397,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
lambda source, destination, *indices: mgpu.dialect.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
@ -426,7 +422,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
lambda source, destination, *indices: mgpu.dialect.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
@ -449,7 +445,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
lambda source, destination, *indices: mgpu.dialect.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
@ -472,7 +468,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
lambda source, destination, *indices: mgpu.dialect.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
@ -496,7 +492,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.MemRefType.get([4, 5, 32, 32], ir.BF16Type.get()),
|
||||
name="wgmma",
|
||||
)(
|
||||
lambda accumulator, a, b: mgpu.wgmma(
|
||||
lambda accumulator, a, b: mgpu.dialect.wgmma(
|
||||
accumulator,
|
||||
a,
|
||||
b,
|
||||
@ -518,7 +514,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.MemRefType.get([5, 32, 32], ir.BF16Type.get()),
|
||||
name="wgmma",
|
||||
)(
|
||||
lambda accumulator, a, b: mgpu.wgmma(
|
||||
lambda accumulator, a, b: mgpu.dialect.wgmma(
|
||||
accumulator,
|
||||
a,
|
||||
b,
|
||||
@ -540,7 +536,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.MemRefType.get([4, 5, 32, 16], ir.BF16Type.get()),
|
||||
name="wgmma",
|
||||
)(
|
||||
lambda accumulator, a, b: mgpu.wgmma(
|
||||
lambda accumulator, a, b: mgpu.dialect.wgmma(
|
||||
accumulator,
|
||||
a,
|
||||
b,
|
||||
@ -563,7 +559,7 @@ class DialectTest(MosaicGpuTest):
|
||||
ir.MemRefType.get([4, 5, 64, 32], ir.BF16Type.get()),
|
||||
name="wgmma",
|
||||
)(
|
||||
lambda accumulator, a, b: mgpu.wgmma(
|
||||
lambda accumulator, a, b: mgpu.dialect.wgmma(
|
||||
accumulator,
|
||||
a,
|
||||
b,
|
||||
@ -585,12 +581,12 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
|
||||
def test_lowering_removes_mosaic_gpu_ops(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
mgpu.dialect.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1,
|
||||
)
|
||||
lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
|
||||
self.assertEmpty(
|
||||
list(filter(is_mosaic_gpu_op, self.module.body.operations))
|
||||
@ -602,13 +598,13 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
|
||||
if_op = scf.IfOp(cst_true)
|
||||
with ir.InsertionPoint(if_op.then_block):
|
||||
mgpu.initialize_barrier(
|
||||
mgpu.dialect.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=1,
|
||||
)
|
||||
scf.yield_([])
|
||||
lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
|
||||
self.assertEmpty(
|
||||
list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
|
||||
@ -620,7 +616,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
arrival_count = 1337
|
||||
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
barriers_ref = mgpu.initialize_barrier(
|
||||
barriers_ref = mgpu.dialect.initialize_barrier(
|
||||
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
llvm.UndefOp(workgroup_ptr_ty()),
|
||||
arrival_count=arrival_count,
|
||||
@ -630,7 +626,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
memref.copy(barriers_ref, barriers_ref)
|
||||
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
|
||||
all_mbarrier_init_shared_ops = find_if(
|
||||
@ -658,7 +654,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "missing a layout and can not be lowered"
|
||||
):
|
||||
lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
|
||||
def test_lowering_eliminates_layouts(self):
|
||||
shape = (4, 128)
|
||||
@ -669,10 +665,10 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
ty = ir.VectorType.get(shape, elt_ty)
|
||||
load = vector.load(ty, ref, [zero_index, zero_index])
|
||||
load.owner.attributes["out_layouts"] = ir.ArrayAttr.get(
|
||||
[strided_fragmented_layout()]
|
||||
[mgpu.strided_fragmented_layout()]
|
||||
)
|
||||
|
||||
lower_mgpu_dialect(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
|
||||
all_ops_with_layouts = find_if(
|
||||
self.module,
|
||||
@ -692,8 +688,8 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
array = vector.load(ty, ref, [zero_index, zero_index])
|
||||
vector.store(array, ref, [zero_index, zero_index])
|
||||
|
||||
infer_layout(self.module)
|
||||
lower_mgpu_dialect(self.module)
|
||||
mgpu.infer_layout(self.module)
|
||||
mgpu.lower_mgpu_dialect(self.module)
|
||||
|
||||
all_loads = find_if(
|
||||
self.module,
|
||||
|
Loading…
x
Reference in New Issue
Block a user