[Mosaic GPU] Implement basic WGMMAFragLayout inference and propagation

PiperOrigin-RevId: 718781860
This commit is contained in:
Dimitar (Mitko) Asenov 2025-01-23 02:47:23 -08:00 committed by jax authors
parent f3e27b6c28
commit 3a411d883a
5 changed files with 89 additions and 67 deletions

View File

@ -55,12 +55,6 @@ from .fragmented_array import (
WGStridedFragLayout as WGStridedFragLayout,
optimization_barrier as optimization_barrier,
)
from .layouts import (
from_strided_fragmented_layout_attr as from_strided_fragmented_layout_attr,
is_strided_fragmented_layout as is_strided_fragmented_layout,
to_splat_fragmented_layout_attr as to_splat_fragmented_layout_attr,
to_strided_fragmented_layout_attr as to_strided_fragmented_layout_attr,
)
from .utils import (
BarrierRef as BarrierRef,
CollectiveBarrierRef as CollectiveBarrierRef,

View File

@ -19,6 +19,7 @@ import enum
from functools import partial
from typing import cast
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
@ -53,8 +54,8 @@ def _choose_representative_layout(
"""Chooses an appropriate layout from a given set of possible layouts.
Given the input set of possible layouts, this function extracts a single
representative layout. Currently, this function only works with strided and
splat fragmented layouts.
representative layout. Currently, this function only works with strided,
splat, and WGMMA fragmented layouts.
Returns:
A single layout that can be used to annotate the operation, or None if the
@ -65,20 +66,31 @@ def _choose_representative_layout(
return None
strided_layouts: list[fa.WGStridedFragLayout] = [
layouts_lib.from_strided_fragmented_layout_attr(layout)
layouts_lib.from_layout_attr(layout)
for layout in layouts
if layouts_lib.is_strided_fragmented_layout(layout)
]
splat_layouts: list[fa.WGSplatFragLayout] = list(
map(
layouts_lib.from_splat_fragmented_layout_attr,
layouts_lib.from_layout_attr,
filter(layouts_lib.is_splat_fragmented_layout, layouts),
)
)
if len(splat_layouts) + len(strided_layouts) != len(layouts):
raise ValueError(f"Expected only strided and splat layouts, got {layouts}")
wgmma_layouts: list[fa.WGMMAFragLayout] = list(
map(
layouts_lib.from_layout_attr,
filter(layouts_lib.is_wgmma_fragmented_layout, layouts),
)
)
if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len(
layouts
):
raise ValueError(
f"Expected only strided, splat, and wgmma layouts, got {layouts}"
)
if len(splat_layouts) > 1:
raise NotImplementedError(
@ -92,14 +104,20 @@ def _choose_representative_layout(
"is not supported."
)
if not splat_layouts:
return layouts_lib.to_strided_fragmented_layout_attr(strided_layouts[0])
if (wgmma_layouts and strided_layouts):
raise NotImplementedError(
"Mixing strided and WGMMA layouts is not supported."
)
if not strided_layouts:
return layouts_lib.to_splat_fragmented_layout_attr(splat_layouts[0])
if wgmma_layouts:
return layouts_lib.to_layout_attr(wgmma_layouts[0])
[strided_layout] = strided_layouts
return layouts_lib.to_strided_fragmented_layout_attr(strided_layout)
if strided_layouts:
[strided_layout] = strided_layouts
return layouts_lib.to_layout_attr(strided_layout)
[splat_layout] = splat_layouts
return layouts_lib.to_layout_attr(splat_layout)
def _in_layout_for_operand(
@ -282,6 +300,16 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
return [], [layout]
@partial(_add_layout_inference_rule, mgpu.WGMMAOp)
def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout())
if ir.VectorType.isinstance(wgmma_op.a.type):
return [layout, layout], [layout]
return [layout], [layout]
class TraversalOrder(enum.Enum):
"""Traversal orders with respect to the data flow for IR."""

View File

@ -53,7 +53,7 @@ def from_splat_fragmented_layout_attr(attr: ir.Attribute) -> fa.WGSplatFragLayou
def is_splat_fragmented_layout(attr: ir.Attribute) -> bool:
return bool(re.search(_splat_fragmented_layout_attr_pattern, str(attr)))
return bool(_splat_fragmented_layout_attr_pattern.search(str(attr)))
_strided_fragmented_layout_attr_pattern = re.compile(
@ -92,6 +92,10 @@ def from_strided_fragmented_layout_attr(
)
def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
return bool(_strided_fragmented_layout_attr_pattern.search(str(attr)))
def to_layout_attr(
layout: (
fa.WGSplatFragLayout
@ -121,11 +125,19 @@ _wgmma_fragmented_layout_attr_pattern = re.compile(
)
def is_wgmma_fragmented_layout(attr: ir.Attribute) -> bool:
return bool(_wgmma_fragmented_layout_attr_pattern.search(str(attr)))
_wgmma_row_fragmented_layout_attr_pattern = re.compile(
r"^#mosaic_gpu.WGMMARowFragLayout$"
)
def is_wgmma_row_fragmented_layout(attr: ir.Attribute) -> bool:
return bool(_wgmma_row_fragmented_layout_attr_pattern.search(str(attr)))
def from_layout_attr(
attr: ir.Attribute,
) -> (
@ -135,13 +147,13 @@ def from_layout_attr(
| fa.WGMMARowFragLayout
):
"""Constructs a layout from an MLIR attribute."""
if _splat_fragmented_layout_attr_pattern.fullmatch(str(attr)):
if is_splat_fragmented_layout(attr):
return from_splat_fragmented_layout_attr(attr)
elif _strided_fragmented_layout_attr_pattern.fullmatch(str(attr)):
elif is_strided_fragmented_layout(attr):
return from_strided_fragmented_layout_attr(attr)
elif _wgmma_fragmented_layout_attr_pattern.fullmatch(str(attr)):
elif is_wgmma_fragmented_layout(attr):
return fa.WGMMAFragLayout()
elif _wgmma_row_fragmented_layout_attr_pattern.fullmatch(str(attr)):
elif is_wgmma_row_fragmented_layout(attr):
return fa.WGMMARowFragLayout()
else:
raise NotImplementedError(
@ -149,10 +161,6 @@ def from_layout_attr(
)
def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
return bool(re.search(_strided_fragmented_layout_attr_pattern, str(attr)))
def in_layouts(op: ir.OpView) -> Sequence[ir.Attribute]:
"""Returns the in_layouts attribute of the given operation.

View File

@ -17,6 +17,7 @@
from typing import Callable
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir as mlir_interpreter
@ -29,8 +30,9 @@ from jax._src.lib.mlir.dialects import memref
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
import jax.experimental.mosaic.gpu as mgpu
import jax.experimental.mosaic.gpu.utils as mgpu_utils
from jax.experimental.mosaic import gpu as mgpu
from jax.experimental.mosaic.gpu import layouts
from jax.experimental.mosaic.gpu import utils as mgpu_utils
_cext = mgpu.dialect._cext if mgpu.dialect is not None else None
@ -84,6 +86,8 @@ class MosaicGpuTest(parameterized.TestCase):
def setUp(self):
if mgpu.dialect is None:
raise self.skipTest("Test requires Mosaic GPU dialect")
if jax.version._version != jax.lib.__version__:
raise self.skipTest("Test requires matching jax and jaxlib versions")
super().setUp()
self.enter_context(_make_ir_context())
self.enter_context(ir.Location.unknown())
@ -666,9 +670,7 @@ class DialectLoweringTest(MosaicGpuTest):
ty = ir.VectorType.get(shape, elt_ty)
load = vector.load(ty, ref, [zero_index, zero_index])
load.owner.attributes["out_layouts"] = ir.ArrayAttr.get([
mgpu.to_strided_fragmented_layout_attr(
mgpu.WGStridedFragLayout.from_shaped_type(ty)
)
layouts.to_layout_attr(mgpu.WGStridedFragLayout.from_shaped_type(ty))
])
mgpu.lower_mgpu_dialect(self.module, None)

View File

@ -15,6 +15,7 @@
"""Layout inference tests for the Mosaic GPU MLIR dialect."""
from absl.testing import parameterized
import jax
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir as mlir_interpreter
@ -24,6 +25,7 @@ from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
import jax.experimental.mosaic.gpu as mgpu
from jax.experimental.mosaic.gpu import layouts
config.parse_flags_with_absl()
@ -36,20 +38,13 @@ def _make_ir_context():
return context
def _layout_to_attr(
layout: mgpu.WGSplatFragLayout | mgpu.WGStridedFragLayout,
) -> ir.Attribute:
if isinstance(layout, mgpu.WGSplatFragLayout):
return mgpu.to_splat_fragmented_layout_attr(layout)
else:
return mgpu.to_strided_fragmented_layout_attr(layout)
class LayoutInferenceTest(parameterized.TestCase):
def setUp(self):
if mgpu.dialect is None:
raise self.skipTest("Test requires Mosaic GPU dialect")
if jax.version._version != jax.lib.__version__:
raise self.skipTest("Test requires matching jax and jaxlib versions")
super().setUp()
self.enter_context(_make_ir_context())
self.enter_context(ir.Location.unknown())
@ -72,7 +67,7 @@ class LayoutInferenceTest(parameterized.TestCase):
# strided fragmented layout.
mgpu.infer_layout(self.module)
layout = mgpu.to_strided_fragmented_layout_attr(
layout = layouts.to_layout_attr(
mgpu.WGStridedFragLayout.from_shaped_type(ty)
)
@ -95,9 +90,7 @@ class LayoutInferenceTest(parameterized.TestCase):
# splat fragmented layout.
mgpu.infer_layout(self.module)
layout = mgpu.to_splat_fragmented_layout_attr(
mgpu.WGSplatFragLayout(shape=shape)
)
layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
self.assertEmpty(splat0.attributes["in_layouts"])
self.assertSequenceEqual(splat0.attributes["out_layouts"], [layout])
@ -120,7 +113,7 @@ class LayoutInferenceTest(parameterized.TestCase):
c = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attr_list, ty))
add = arith.AddFOp(c, c)
layout = mgpu.to_strided_fragmented_layout_attr(
layout = layouts.to_layout_attr(
mgpu.WGStridedFragLayout(shape=shape, vec_size=1)
)
add.attributes["in_layouts"] = ir.ArrayAttr.get([layout, layout])
@ -148,9 +141,7 @@ class LayoutInferenceTest(parameterized.TestCase):
# splat fragmented layout.
mgpu.infer_layout(self.module)
layout = mgpu.to_splat_fragmented_layout_attr(
mgpu.WGSplatFragLayout(shape=shape)
)
layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
self.assertEmpty(splat.attributes["in_layouts"])
self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])
@ -174,7 +165,7 @@ class LayoutInferenceTest(parameterized.TestCase):
func.FuncOp.from_py_func(ty, ty)(body)
[f] = self.module.body.operations
layout_attr = _layout_to_attr(layout)
layout_attr = layouts.to_layout_attr(layout)
f.attributes["in_layouts"] = ir.ArrayAttr.get([layout_attr, layout_attr])
mgpu.infer_layout(self.module)
@ -233,7 +224,13 @@ class LayoutInferenceTest(parameterized.TestCase):
self.assertLen(vector_store.attributes["in_layouts"], 1)
self.assertEmpty(vector_store.attributes["out_layouts"])
def test_infer_layout_picks_strided_layout_over_splat_layout(self):
@parameterized.parameters(
mgpu.WGStridedFragLayout((32, 4), vec_size=1),
mgpu.WGMMAFragLayout(),
)
def test_infer_layout_picks_non_splat_layout_over_splat_layout(
self, layout
):
add = None
def body(lhs, rhs):
@ -247,23 +244,20 @@ class LayoutInferenceTest(parameterized.TestCase):
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
splat_layout = mgpu.to_splat_fragmented_layout_attr(
mgpu.WGSplatFragLayout(shape)
)
strided_layout = mgpu.to_strided_fragmented_layout_attr(
mgpu.WGStridedFragLayout(shape, vec_size=1)
)
splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
non_splat_layout = layouts.to_layout_attr(layout)
f.attributes["in_layouts"] = ir.ArrayAttr.get(
[strided_layout, splat_layout]
[non_splat_layout, splat_layout]
)
mgpu.infer_layout(self.module)
self.assertSequenceEqual(
add.attributes["in_layouts"], [strided_layout, strided_layout]
add.attributes["in_layouts"],
[non_splat_layout, non_splat_layout],
)
self.assertSequenceEqual(add.attributes["out_layouts"], [strided_layout])
self.assertSequenceEqual(add.attributes["out_layouts"], [non_splat_layout])
def test_infer_layout_preserves_splat_layouts_in_producers(self):
add0 = add1 = None
@ -279,10 +273,8 @@ class LayoutInferenceTest(parameterized.TestCase):
ty = ir.VectorType.get(shape, elt_type)
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
splat_layout = mgpu.to_splat_fragmented_layout_attr(
mgpu.WGSplatFragLayout(shape)
)
strided_layout = mgpu.to_strided_fragmented_layout_attr(
splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
strided_layout = layouts.to_layout_attr(
mgpu.WGStridedFragLayout(shape, vec_size=1)
)
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
@ -311,9 +303,7 @@ class LayoutInferenceTest(parameterized.TestCase):
ty = ir.VectorType.get(shape, ir.BF16Type.get())
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
splat_layout = mgpu.to_splat_fragmented_layout_attr(
mgpu.WGSplatFragLayout(shape)
)
splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
mgpu.infer_layout(self.module)