rocm_jax/tests/mosaic/gpu_layout_inference_test.py
2025-01-16 07:18:35 -08:00

341 lines
11 KiB
Python

# Copyright 2024 The JAX Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layout inference tests for the Mosaic GPU MLIR dialect."""
from absl.testing import parameterized
from jax._src import config
from jax._src import test_util as jtu
from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
import jax.experimental.mosaic.gpu as mgpu
# Have JAX's config module parse the command-line flags with absl when this
# module is loaded.
config.parse_flags_with_absl()
def _make_ir_context():
  """Builds an MLIR context that has the Mosaic GPU dialect registered.

  The upstream dialect registry is appended to the new context and every
  available dialect is loaded before the Mosaic GPU dialect is registered.

  Returns:
    The newly created `ir.Context`.
  """
  ctx = ir.Context()
  ctx.append_dialect_registry(mlir_interpreter.upstream_dialects)
  ctx.load_all_available_dialects()
  mgpu.dialect.register_dialect(ctx)
  return ctx
def _layout_to_attr(
    layout: mgpu.WGSplatFragLayout | mgpu.WGStridedFragLayout,
) -> ir.Attribute:
  """Converts a splat or strided fragmented layout into an MLIR attribute.

  Args:
    layout: The layout to convert. A `WGSplatFragLayout` goes through the
      splat converter; anything else is handed to the strided converter.

  Returns:
    The attribute produced by the matching `mgpu.to_*_layout_attr` helper.
  """
  is_splat = isinstance(layout, mgpu.WGSplatFragLayout)
  # Only the converter that is actually needed gets looked up and called.
  to_attr = (
      mgpu.to_splat_fragmented_layout_attr
      if is_splat
      else mgpu.to_strided_fragmented_layout_attr
  )
  return to_attr(layout)
class LayoutInferenceTest(parameterized.TestCase):
  """Tests for `mgpu.infer_layout` on small, hand-built MLIR modules.

  Each test builds ops inside `self.module`, optionally pins some
  `in_layouts`/`out_layouts` attributes, runs `mgpu.infer_layout` on the
  module, and then checks the layout attributes found on the ops.
  """

  def setUp(self):
    """Skips without the Mosaic GPU dialect; otherwise sets up MLIR state."""
    if mgpu.dialect is None:
      # `skipTest` raises `unittest.SkipTest` itself, so no `raise` is needed.
      self.skipTest("Test requires Mosaic GPU dialect")
    super().setUp()
    # Keep the context and the unknown location active for the whole test.
    self.enter_context(_make_ir_context())
    self.enter_context(ir.Location.unknown())
    self.module = ir.Module.create()

  def test_infer_strided_layout_default(self):
    """An add with no layout hints gets the type's strided layout."""
    shape = (16, 8)
    elt_type = ir.BF16Type.get()
    add = None

    def body(a, b):
      nonlocal add
      add = arith.AddFOp(a, b)

    with ir.InsertionPoint(self.module.body):
      ty = ir.VectorType.get(shape, elt_type)
      func.FuncOp.from_py_func(ty, ty)(body)

    # With no layouts set anywhere in the module, inference should result in
    # the op having a strided fragmented layout.
    mgpu.infer_layout(self.module)

    layout = mgpu.to_strided_fragmented_layout_attr(
        mgpu.WGStridedFragLayout.from_shaped_type(ty)
    )

    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])

  def test_infer_splat_layout_for_splat_constants(self):
    """Splat constants and the add consuming them get a splat layout."""
    shape = (16, 8)
    elt_type = ir.BF16Type.get()

    with ir.InsertionPoint(self.module.body):
      ty = ir.VectorType.get(shape, elt_type)
      c0 = ir.FloatAttr.get(elt_type, 0)
      c1 = ir.FloatAttr.get(elt_type, 1)
      splat0 = arith.ConstantOp(ty, ir.DenseElementsAttr.get_splat(ty, c0))
      splat1 = arith.ConstantOp(ty, ir.DenseElementsAttr.get_splat(ty, c1))
      add = arith.AddFOp(splat0, splat1)

    # With no layouts set anywhere in the module, inference should result in
    # all ops having a splat fragmented layout.
    mgpu.infer_layout(self.module)

    layout = mgpu.to_splat_fragmented_layout_attr(
        mgpu.WGSplatFragLayout(shape=shape)
    )

    self.assertEmpty(splat0.attributes["in_layouts"])
    self.assertSequenceEqual(splat0.attributes["out_layouts"], [layout])

    self.assertEmpty(splat1.attributes["in_layouts"])
    self.assertSequenceEqual(splat1.attributes["out_layouts"], [layout])

    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])

  def test_infer_layout_from_consumer_for_non_splat_constant(self):
    """A non-splat constant takes the layout its consumer asks for."""
    shape = (16, 8)
    elt_type = ir.BF16Type.get()

    with ir.InsertionPoint(self.module.body):
      ty = ir.VectorType.get(shape, elt_type)
      attr_list = [
          ir.FloatAttr.get(elt_type, i) for i in range(shape[0] * shape[1])
      ]
      c = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attr_list, ty))
      add = arith.AddFOp(c, c)

    layout = mgpu.to_strided_fragmented_layout_attr(
        mgpu.WGStridedFragLayout(shape=shape, vec_size=1)
    )
    # Pin the consumer's input layouts before running inference.
    add.attributes["in_layouts"] = ir.ArrayAttr.get([layout, layout])
    mgpu.infer_layout(self.module)

    self.assertEmpty(c.attributes["in_layouts"])
    self.assertSequenceEqual(c.attributes["out_layouts"], [layout])

  def test_infer_splat_layout_for_vector_splat(self):
    """`vector.SplatOp` and the add consuming it get a splat layout."""
    add = splat = None

    def body(lhs, rhs):
      nonlocal add, splat
      splat = vector.SplatOp(rhs.type, lhs)
      add = arith.AddFOp(splat, rhs)

    with ir.InsertionPoint(self.module.body):
      shape = (16, 8)
      elt_type = ir.BF16Type.get()
      ty = ir.VectorType.get(shape, elt_type)
      func.FuncOp.from_py_func(elt_type, ty)(body)

    # With no layouts set anywhere in the module, inference should result in
    # all ops having a splat fragmented layout.
    mgpu.infer_layout(self.module)

    layout = mgpu.to_splat_fragmented_layout_attr(
        mgpu.WGSplatFragLayout(shape=shape)
    )

    self.assertEmpty(splat.attributes["in_layouts"])
    self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])

    self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
    self.assertSequenceEqual(add.attributes["out_layouts"], [layout])

  @parameterized.parameters(
      mgpu.WGSplatFragLayout(shape=(32, 4)),
      mgpu.WGStridedFragLayout(shape=(32, 4), vec_size=1),
  )
  def test_pointwise_op_propagates_argument_layouts(self, layout):
    """An add takes the layout set on both function arguments."""
    add = None

    def body(lhs, rhs):
      nonlocal add
      add = arith.AddFOp(lhs, rhs)

    with ir.InsertionPoint(self.module.body):
      ty = ir.VectorType.get(layout.shape, ir.BF16Type.get())
      # Fetch the function op the same way the other tests in this class do.
      f = func.FuncOp.from_py_func(ty, ty)(body).func_op

    layout_attr = _layout_to_attr(layout)
    f.attributes["in_layouts"] = ir.ArrayAttr.get([layout_attr, layout_attr])

    mgpu.infer_layout(self.module)

    self.assertSequenceEqual(
        add.attributes["in_layouts"], [layout_attr, layout_attr]
    )
    self.assertSequenceEqual(add.attributes["out_layouts"], [layout_attr])

  def test_infer_layout_traverses_ops_correctly(self):
    """An add nested in an `scf.IfOp` then-block still gets layouts."""
    shape = (16, 8)
    elt_type = ir.BF16Type.get()
    add = None

    def body(a, b):
      nonlocal add
      bool_type = ir.IntegerType.get_signless(1)
      cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
      if_op = scf.IfOp(cst_true)
      with ir.InsertionPoint(if_op.then_block):
        add = arith.AddFOp(a, b)
        scf.yield_([])

    with ir.InsertionPoint(self.module.body):
      ab_type = ir.VectorType.get(shape, elt_type)
      func.FuncOp.from_py_func(ab_type, ab_type)(body)

    mgpu.infer_layout(self.module)

    self.assertIn("in_layouts", add.attributes)
    self.assertIn("out_layouts", add.attributes)

  def test_infer_layout_has_no_layout_for_non_vector_types(self):
    """A vector store gets one input layout (the vector) and no outputs."""
    shape = (32, 4)
    elt_ty = ir.BF16Type.get()
    vector_store = None

    def body(ref, array):
      nonlocal vector_store
      zero_index = arith.constant(ir.IndexType.get(), 0)
      vector_store = vector.store(array, ref, [zero_index, zero_index])

    with ir.InsertionPoint(self.module.body):
      ref_ty = ir.MemRefType.get(shape, elt_ty)
      array_ty = ir.VectorType.get(shape, elt_ty)
      func.FuncOp.from_py_func(ref_ty, array_ty)(body)

    mgpu.infer_layout(self.module)

    self.assertIn("in_layouts", vector_store.attributes)
    self.assertIn("out_layouts", vector_store.attributes)

    # The vector store should have a layout for the input array, but not for
    # the memref.
    self.assertLen(vector_store.attributes["in_layouts"], 1)
    self.assertEmpty(vector_store.attributes["out_layouts"])

  def test_infer_layout_picks_strided_layout_over_splat_layout(self):
    """With one strided and one splat argument, the add ends up strided."""
    add = None

    def body(lhs, rhs):
      nonlocal add
      add = arith.AddFOp(lhs, rhs)

    with ir.InsertionPoint(self.module.body):
      shape = (32, 4)
      elt_type = ir.BF16Type.get()
      ty = ir.VectorType.get(shape, elt_type)

      f = func.FuncOp.from_py_func(ty, ty)(body).func_op

    splat_layout = mgpu.to_splat_fragmented_layout_attr(
        mgpu.WGSplatFragLayout(shape)
    )
    strided_layout = mgpu.to_strided_fragmented_layout_attr(
        mgpu.WGStridedFragLayout(shape, vec_size=1)
    )

    f.attributes["in_layouts"] = ir.ArrayAttr.get(
        [strided_layout, splat_layout]
    )

    mgpu.infer_layout(self.module)

    self.assertSequenceEqual(
        add.attributes["in_layouts"], [strided_layout, strided_layout]
    )
    self.assertSequenceEqual(add.attributes["out_layouts"], [strided_layout])

  def test_infer_layout_preserves_splat_layouts_in_producers(self):
    """A strided consumer does not change its producer's splat layouts."""
    add0 = add1 = None

    def body(lhs, rhs):
      nonlocal add0, add1
      add0 = arith.AddFOp(lhs, rhs)
      add1 = arith.AddFOp(add0, add0)

    with ir.InsertionPoint(self.module.body):
      shape = (32, 4)
      elt_type = ir.BF16Type.get()
      ty = ir.VectorType.get(shape, elt_type)
      f = func.FuncOp.from_py_func(ty, ty)(body).func_op

    splat_layout = mgpu.to_splat_fragmented_layout_attr(
        mgpu.WGSplatFragLayout(shape)
    )
    strided_layout = mgpu.to_strided_fragmented_layout_attr(
        mgpu.WGStridedFragLayout(shape, vec_size=1)
    )
    # Splat layouts come in through the function arguments, while the final
    # add is pinned to produce a strided layout.
    f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
    add1.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout])
    mgpu.infer_layout(self.module)

    self.assertSequenceEqual(
        add0.attributes["in_layouts"], [splat_layout, splat_layout]
    )
    self.assertSequenceEqual(
        add1.attributes["in_layouts"], [strided_layout, strided_layout]
    )

    self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout])
    self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout])

  def test_infer_layout_propagates_func_layouts_to_ops(self):
    """Layouts set on the function's inputs reach the add in its body."""
    add = None

    def body(lhs, rhs):
      nonlocal add
      add = arith.AddFOp(lhs, rhs)

    with ir.InsertionPoint(self.module.body):
      shape = (32, 4)
      ty = ir.VectorType.get(shape, ir.BF16Type.get())
      f = func.FuncOp.from_py_func(ty, ty)(body).func_op

    splat_layout = mgpu.to_splat_fragmented_layout_attr(
        mgpu.WGSplatFragLayout(shape)
    )
    f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
    mgpu.infer_layout(self.module)

    self.assertSequenceEqual(
        add.attributes["in_layouts"], [splat_layout, splat_layout]
    )
    self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout])

  def test_infer_layout_does_not_assign_default_layouts_to_func(self):
    """Inference leaves a function without layout attributes untouched."""

    def body(lhs, rhs):
      arith.AddFOp(lhs, rhs)

    with ir.InsertionPoint(self.module.body):
      shape = (32, 4)
      ty = ir.VectorType.get(shape, ir.BF16Type.get())
      f = func.FuncOp.from_py_func(ty, ty)(body).func_op

    mgpu.infer_layout(self.module)
    self.assertNotIn("in_layouts", f.attributes)
    self.assertNotIn("out_layouts", f.attributes)
if __name__ == "__main__":
  # Run the tests through absltest, using JAX's test loader.
  loader = jtu.JaxTestLoader()
  parameterized.absltest.main(testLoader=loader)