mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic GPU] Implement basic WGMMAFragLayout inference and propagation
PiperOrigin-RevId: 718781860
This commit is contained in:
parent
f3e27b6c28
commit
3a411d883a
@ -55,12 +55,6 @@ from .fragmented_array import (
|
||||
WGStridedFragLayout as WGStridedFragLayout,
|
||||
optimization_barrier as optimization_barrier,
|
||||
)
|
||||
from .layouts import (
|
||||
from_strided_fragmented_layout_attr as from_strided_fragmented_layout_attr,
|
||||
is_strided_fragmented_layout as is_strided_fragmented_layout,
|
||||
to_splat_fragmented_layout_attr as to_splat_fragmented_layout_attr,
|
||||
to_strided_fragmented_layout_attr as to_strided_fragmented_layout_attr,
|
||||
)
|
||||
from .utils import (
|
||||
BarrierRef as BarrierRef,
|
||||
CollectiveBarrierRef as CollectiveBarrierRef,
|
||||
|
@ -19,6 +19,7 @@ import enum
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
|
||||
from jax._src.lib import mosaic_gpu_dialect as mgpu
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import func
|
||||
@ -53,8 +54,8 @@ def _choose_representative_layout(
|
||||
"""Chooses an appropriate layout from a given set of possible layouts.
|
||||
|
||||
Given the input set of possible layouts, this function extracts a single
|
||||
representative layout. Currently, this function only works with strided and
|
||||
splat fragmented layouts.
|
||||
representative layout. Currently, this function only works with strided,
|
||||
splat, and WGMMA fragmented layouts.
|
||||
|
||||
Returns:
|
||||
A single layout that can be used to annotate the operation, or None if the
|
||||
@ -65,20 +66,31 @@ def _choose_representative_layout(
|
||||
return None
|
||||
|
||||
strided_layouts: list[fa.WGStridedFragLayout] = [
|
||||
layouts_lib.from_strided_fragmented_layout_attr(layout)
|
||||
layouts_lib.from_layout_attr(layout)
|
||||
for layout in layouts
|
||||
if layouts_lib.is_strided_fragmented_layout(layout)
|
||||
]
|
||||
|
||||
splat_layouts: list[fa.WGSplatFragLayout] = list(
|
||||
map(
|
||||
layouts_lib.from_splat_fragmented_layout_attr,
|
||||
layouts_lib.from_layout_attr,
|
||||
filter(layouts_lib.is_splat_fragmented_layout, layouts),
|
||||
)
|
||||
)
|
||||
|
||||
if len(splat_layouts) + len(strided_layouts) != len(layouts):
|
||||
raise ValueError(f"Expected only strided and splat layouts, got {layouts}")
|
||||
wgmma_layouts: list[fa.WGMMAFragLayout] = list(
|
||||
map(
|
||||
layouts_lib.from_layout_attr,
|
||||
filter(layouts_lib.is_wgmma_fragmented_layout, layouts),
|
||||
)
|
||||
)
|
||||
|
||||
if len(splat_layouts) + len(strided_layouts) + len(wgmma_layouts) != len(
|
||||
layouts
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected only strided, splat, and wgmma layouts, got {layouts}"
|
||||
)
|
||||
|
||||
if len(splat_layouts) > 1:
|
||||
raise NotImplementedError(
|
||||
@ -92,14 +104,20 @@ def _choose_representative_layout(
|
||||
"is not supported."
|
||||
)
|
||||
|
||||
if not splat_layouts:
|
||||
return layouts_lib.to_strided_fragmented_layout_attr(strided_layouts[0])
|
||||
if (wgmma_layouts and strided_layouts):
|
||||
raise NotImplementedError(
|
||||
"Mixing strided and WGMMA layouts is not supported."
|
||||
)
|
||||
|
||||
if not strided_layouts:
|
||||
return layouts_lib.to_splat_fragmented_layout_attr(splat_layouts[0])
|
||||
if wgmma_layouts:
|
||||
return layouts_lib.to_layout_attr(wgmma_layouts[0])
|
||||
|
||||
[strided_layout] = strided_layouts
|
||||
return layouts_lib.to_strided_fragmented_layout_attr(strided_layout)
|
||||
if strided_layouts:
|
||||
[strided_layout] = strided_layouts
|
||||
return layouts_lib.to_layout_attr(strided_layout)
|
||||
|
||||
[splat_layout] = splat_layouts
|
||||
return layouts_lib.to_layout_attr(splat_layout)
|
||||
|
||||
|
||||
def _in_layout_for_operand(
|
||||
@ -282,6 +300,16 @@ def _infer_splat_op_layout(splat_op: vector.SplatOp) -> OptionalLayouts:
|
||||
return [], [layout]
|
||||
|
||||
|
||||
@partial(_add_layout_inference_rule, mgpu.WGMMAOp)
|
||||
def _infer_wgmma_op_layout(wgmma_op: mgpu.WGMMAOp) -> OptionalLayouts:
|
||||
layout = layouts_lib.to_layout_attr(fa.WGMMAFragLayout())
|
||||
|
||||
if ir.VectorType.isinstance(wgmma_op.a.type):
|
||||
return [layout, layout], [layout]
|
||||
|
||||
return [layout], [layout]
|
||||
|
||||
|
||||
class TraversalOrder(enum.Enum):
|
||||
"""Traversal orders with respect to the data flow for IR."""
|
||||
|
||||
|
@ -53,7 +53,7 @@ def from_splat_fragmented_layout_attr(attr: ir.Attribute) -> fa.WGSplatFragLayou
|
||||
|
||||
|
||||
def is_splat_fragmented_layout(attr: ir.Attribute) -> bool:
|
||||
return bool(re.search(_splat_fragmented_layout_attr_pattern, str(attr)))
|
||||
return bool(_splat_fragmented_layout_attr_pattern.search(str(attr)))
|
||||
|
||||
|
||||
_strided_fragmented_layout_attr_pattern = re.compile(
|
||||
@ -92,6 +92,10 @@ def from_strided_fragmented_layout_attr(
|
||||
)
|
||||
|
||||
|
||||
def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
|
||||
return bool(_strided_fragmented_layout_attr_pattern.search(str(attr)))
|
||||
|
||||
|
||||
def to_layout_attr(
|
||||
layout: (
|
||||
fa.WGSplatFragLayout
|
||||
@ -121,11 +125,19 @@ _wgmma_fragmented_layout_attr_pattern = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def is_wgmma_fragmented_layout(attr: ir.Attribute) -> bool:
|
||||
return bool(_wgmma_fragmented_layout_attr_pattern.search(str(attr)))
|
||||
|
||||
|
||||
_wgmma_row_fragmented_layout_attr_pattern = re.compile(
|
||||
r"^#mosaic_gpu.WGMMARowFragLayout$"
|
||||
)
|
||||
|
||||
|
||||
def is_wgmma_row_fragmented_layout(attr: ir.Attribute) -> bool:
|
||||
return bool(_wgmma_row_fragmented_layout_attr_pattern.search(str(attr)))
|
||||
|
||||
|
||||
def from_layout_attr(
|
||||
attr: ir.Attribute,
|
||||
) -> (
|
||||
@ -135,13 +147,13 @@ def from_layout_attr(
|
||||
| fa.WGMMARowFragLayout
|
||||
):
|
||||
"""Constructs a layout from an MLIR attribute."""
|
||||
if _splat_fragmented_layout_attr_pattern.fullmatch(str(attr)):
|
||||
if is_splat_fragmented_layout(attr):
|
||||
return from_splat_fragmented_layout_attr(attr)
|
||||
elif _strided_fragmented_layout_attr_pattern.fullmatch(str(attr)):
|
||||
elif is_strided_fragmented_layout(attr):
|
||||
return from_strided_fragmented_layout_attr(attr)
|
||||
elif _wgmma_fragmented_layout_attr_pattern.fullmatch(str(attr)):
|
||||
elif is_wgmma_fragmented_layout(attr):
|
||||
return fa.WGMMAFragLayout()
|
||||
elif _wgmma_row_fragmented_layout_attr_pattern.fullmatch(str(attr)):
|
||||
elif is_wgmma_row_fragmented_layout(attr):
|
||||
return fa.WGMMARowFragLayout()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@ -149,10 +161,6 @@ def from_layout_attr(
|
||||
)
|
||||
|
||||
|
||||
def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
|
||||
return bool(re.search(_strided_fragmented_layout_attr_pattern, str(attr)))
|
||||
|
||||
|
||||
def in_layouts(op: ir.OpView) -> Sequence[ir.Attribute]:
|
||||
"""Returns the in_layouts attribute of the given operation.
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
from typing import Callable
|
||||
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir as mlir_interpreter
|
||||
@ -29,8 +30,9 @@ from jax._src.lib.mlir.dialects import memref
|
||||
from jax._src.lib.mlir.dialects import nvvm
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
import jax.experimental.mosaic.gpu as mgpu
|
||||
import jax.experimental.mosaic.gpu.utils as mgpu_utils
|
||||
from jax.experimental.mosaic import gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import layouts
|
||||
from jax.experimental.mosaic.gpu import utils as mgpu_utils
|
||||
|
||||
_cext = mgpu.dialect._cext if mgpu.dialect is not None else None
|
||||
|
||||
@ -84,6 +86,8 @@ class MosaicGpuTest(parameterized.TestCase):
|
||||
def setUp(self):
|
||||
if mgpu.dialect is None:
|
||||
raise self.skipTest("Test requires Mosaic GPU dialect")
|
||||
if jax.version._version != jax.lib.__version__:
|
||||
raise self.skipTest("Test requires matching jax and jaxlib versions")
|
||||
super().setUp()
|
||||
self.enter_context(_make_ir_context())
|
||||
self.enter_context(ir.Location.unknown())
|
||||
@ -666,9 +670,7 @@ class DialectLoweringTest(MosaicGpuTest):
|
||||
ty = ir.VectorType.get(shape, elt_ty)
|
||||
load = vector.load(ty, ref, [zero_index, zero_index])
|
||||
load.owner.attributes["out_layouts"] = ir.ArrayAttr.get([
|
||||
mgpu.to_strided_fragmented_layout_attr(
|
||||
mgpu.WGStridedFragLayout.from_shaped_type(ty)
|
||||
)
|
||||
layouts.to_layout_attr(mgpu.WGStridedFragLayout.from_shaped_type(ty))
|
||||
])
|
||||
|
||||
mgpu.lower_mgpu_dialect(self.module, None)
|
||||
|
@ -15,6 +15,7 @@
|
||||
"""Layout inference tests for the Mosaic GPU MLIR dialect."""
|
||||
|
||||
from absl.testing import parameterized
|
||||
import jax
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir as mlir_interpreter
|
||||
@ -24,6 +25,7 @@ from jax._src.lib.mlir.dialects import func
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
from jax._src.lib.mlir.dialects import vector
|
||||
import jax.experimental.mosaic.gpu as mgpu
|
||||
from jax.experimental.mosaic.gpu import layouts
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
@ -36,20 +38,13 @@ def _make_ir_context():
|
||||
return context
|
||||
|
||||
|
||||
def _layout_to_attr(
|
||||
layout: mgpu.WGSplatFragLayout | mgpu.WGStridedFragLayout,
|
||||
) -> ir.Attribute:
|
||||
if isinstance(layout, mgpu.WGSplatFragLayout):
|
||||
return mgpu.to_splat_fragmented_layout_attr(layout)
|
||||
else:
|
||||
return mgpu.to_strided_fragmented_layout_attr(layout)
|
||||
|
||||
|
||||
class LayoutInferenceTest(parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if mgpu.dialect is None:
|
||||
raise self.skipTest("Test requires Mosaic GPU dialect")
|
||||
if jax.version._version != jax.lib.__version__:
|
||||
raise self.skipTest("Test requires matching jax and jaxlib versions")
|
||||
super().setUp()
|
||||
self.enter_context(_make_ir_context())
|
||||
self.enter_context(ir.Location.unknown())
|
||||
@ -72,7 +67,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
# strided fragmented layout.
|
||||
mgpu.infer_layout(self.module)
|
||||
|
||||
layout = mgpu.to_strided_fragmented_layout_attr(
|
||||
layout = layouts.to_layout_attr(
|
||||
mgpu.WGStridedFragLayout.from_shaped_type(ty)
|
||||
)
|
||||
|
||||
@ -95,9 +90,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
# splat fragmented layout.
|
||||
mgpu.infer_layout(self.module)
|
||||
|
||||
layout = mgpu.to_splat_fragmented_layout_attr(
|
||||
mgpu.WGSplatFragLayout(shape=shape)
|
||||
)
|
||||
layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
|
||||
|
||||
self.assertEmpty(splat0.attributes["in_layouts"])
|
||||
self.assertSequenceEqual(splat0.attributes["out_layouts"], [layout])
|
||||
@ -120,7 +113,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
c = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attr_list, ty))
|
||||
add = arith.AddFOp(c, c)
|
||||
|
||||
layout = mgpu.to_strided_fragmented_layout_attr(
|
||||
layout = layouts.to_layout_attr(
|
||||
mgpu.WGStridedFragLayout(shape=shape, vec_size=1)
|
||||
)
|
||||
add.attributes["in_layouts"] = ir.ArrayAttr.get([layout, layout])
|
||||
@ -148,9 +141,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
# splat fragmented layout.
|
||||
mgpu.infer_layout(self.module)
|
||||
|
||||
layout = mgpu.to_splat_fragmented_layout_attr(
|
||||
mgpu.WGSplatFragLayout(shape=shape)
|
||||
)
|
||||
layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape=shape))
|
||||
|
||||
self.assertEmpty(splat.attributes["in_layouts"])
|
||||
self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])
|
||||
@ -174,7 +165,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
func.FuncOp.from_py_func(ty, ty)(body)
|
||||
|
||||
[f] = self.module.body.operations
|
||||
layout_attr = _layout_to_attr(layout)
|
||||
layout_attr = layouts.to_layout_attr(layout)
|
||||
f.attributes["in_layouts"] = ir.ArrayAttr.get([layout_attr, layout_attr])
|
||||
|
||||
mgpu.infer_layout(self.module)
|
||||
@ -233,7 +224,13 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
self.assertLen(vector_store.attributes["in_layouts"], 1)
|
||||
self.assertEmpty(vector_store.attributes["out_layouts"])
|
||||
|
||||
def test_infer_layout_picks_strided_layout_over_splat_layout(self):
|
||||
@parameterized.parameters(
|
||||
mgpu.WGStridedFragLayout((32, 4), vec_size=1),
|
||||
mgpu.WGMMAFragLayout(),
|
||||
)
|
||||
def test_infer_layout_picks_non_splat_layout_over_splat_layout(
|
||||
self, layout
|
||||
):
|
||||
add = None
|
||||
|
||||
def body(lhs, rhs):
|
||||
@ -247,23 +244,20 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
|
||||
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
|
||||
|
||||
splat_layout = mgpu.to_splat_fragmented_layout_attr(
|
||||
mgpu.WGSplatFragLayout(shape)
|
||||
)
|
||||
strided_layout = mgpu.to_strided_fragmented_layout_attr(
|
||||
mgpu.WGStridedFragLayout(shape, vec_size=1)
|
||||
)
|
||||
splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
|
||||
non_splat_layout = layouts.to_layout_attr(layout)
|
||||
|
||||
f.attributes["in_layouts"] = ir.ArrayAttr.get(
|
||||
[strided_layout, splat_layout]
|
||||
[non_splat_layout, splat_layout]
|
||||
)
|
||||
|
||||
mgpu.infer_layout(self.module)
|
||||
|
||||
self.assertSequenceEqual(
|
||||
add.attributes["in_layouts"], [strided_layout, strided_layout]
|
||||
add.attributes["in_layouts"],
|
||||
[non_splat_layout, non_splat_layout],
|
||||
)
|
||||
self.assertSequenceEqual(add.attributes["out_layouts"], [strided_layout])
|
||||
self.assertSequenceEqual(add.attributes["out_layouts"], [non_splat_layout])
|
||||
|
||||
def test_infer_layout_preserves_splat_layouts_in_producers(self):
|
||||
add0 = add1 = None
|
||||
@ -279,10 +273,8 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
ty = ir.VectorType.get(shape, elt_type)
|
||||
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
|
||||
|
||||
splat_layout = mgpu.to_splat_fragmented_layout_attr(
|
||||
mgpu.WGSplatFragLayout(shape)
|
||||
)
|
||||
strided_layout = mgpu.to_strided_fragmented_layout_attr(
|
||||
splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
|
||||
strided_layout = layouts.to_layout_attr(
|
||||
mgpu.WGStridedFragLayout(shape, vec_size=1)
|
||||
)
|
||||
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
|
||||
@ -311,9 +303,7 @@ class LayoutInferenceTest(parameterized.TestCase):
|
||||
ty = ir.VectorType.get(shape, ir.BF16Type.get())
|
||||
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
|
||||
|
||||
splat_layout = mgpu.to_splat_fragmented_layout_attr(
|
||||
mgpu.WGSplatFragLayout(shape)
|
||||
)
|
||||
splat_layout = layouts.to_layout_attr(mgpu.WGSplatFragLayout(shape))
|
||||
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
|
||||
mgpu.infer_layout(self.module)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user