mirror of
https://github.com/ROCm/jax.git
synced 2025-04-23 20:56:05 +00:00
341 lines
11 KiB
Python
341 lines
11 KiB
Python
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Layout inference tests for the Mosaic GPU MLIR dialect."""
|
|
|
|
from absl.testing import parameterized
|
|
from jax._src import config
|
|
from jax._src import test_util as jtu
|
|
from jax._src.interpreters import mlir as mlir_interpreter
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import arith
|
|
from jax._src.lib.mlir.dialects import func
|
|
from jax._src.lib.mlir.dialects import scf
|
|
from jax._src.lib.mlir.dialects import vector
|
|
import jax.experimental.mosaic.gpu as mgpu
|
|
|
|
config.parse_flags_with_absl()
|
|
|
|
|
|
def _make_ir_context():
|
|
context = ir.Context()
|
|
context.append_dialect_registry(mlir_interpreter.upstream_dialects)
|
|
context.load_all_available_dialects()
|
|
mgpu.dialect.register_dialect(context)
|
|
return context
|
|
|
|
|
|
def _layout_to_attr(
|
|
layout: mgpu.WGSplatFragLayout | mgpu.WGStridedFragLayout,
|
|
) -> ir.Attribute:
|
|
if isinstance(layout, mgpu.WGSplatFragLayout):
|
|
return mgpu.to_splat_fragmented_layout_attr(layout)
|
|
else:
|
|
return mgpu.to_strided_fragmented_layout_attr(layout)
|
|
|
|
|
|
class LayoutInferenceTest(parameterized.TestCase):
|
|
|
|
def setUp(self):
|
|
if mgpu.dialect is None:
|
|
raise self.skipTest("Test requires Mosaic GPU dialect")
|
|
super().setUp()
|
|
self.enter_context(_make_ir_context())
|
|
self.enter_context(ir.Location.unknown())
|
|
self.module = ir.Module.create()
|
|
|
|
def test_infer_strided_layout_default(self):
|
|
shape = (16, 8)
|
|
elt_type = ir.BF16Type.get()
|
|
add = None
|
|
|
|
def body(a, b):
|
|
nonlocal add
|
|
add = arith.AddFOp(a, b)
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
ty = ir.VectorType.get(shape, elt_type)
|
|
func.FuncOp.from_py_func(ty, ty)(body)
|
|
|
|
# Not setting any layouts on the module should default in ops having a
|
|
# strided fragmented layout.
|
|
mgpu.infer_layout(self.module)
|
|
|
|
layout = mgpu.to_strided_fragmented_layout_attr(
|
|
mgpu.WGStridedFragLayout.from_shaped_type(ty)
|
|
)
|
|
|
|
self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
|
|
self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
|
|
|
|
def test_infer_splat_layout_for_splat_constants(self):
|
|
shape = (16, 8)
|
|
elt_type = ir.BF16Type.get()
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
ty = ir.VectorType.get(shape, elt_type)
|
|
c0 = ir.FloatAttr.get(elt_type, 0)
|
|
c1 = ir.FloatAttr.get(elt_type, 1)
|
|
splat0 = arith.ConstantOp(ty, ir.DenseElementsAttr.get_splat(ty, c0))
|
|
splat1 = arith.ConstantOp(ty, ir.DenseElementsAttr.get_splat(ty, c1))
|
|
add = arith.AddFOp(splat0, splat1)
|
|
|
|
# Not setting any layouts on the module should default in all ops having a
|
|
# splat fragmented layout.
|
|
mgpu.infer_layout(self.module)
|
|
|
|
layout = mgpu.to_splat_fragmented_layout_attr(
|
|
mgpu.WGSplatFragLayout(shape=shape)
|
|
)
|
|
|
|
self.assertEmpty(splat0.attributes["in_layouts"])
|
|
self.assertSequenceEqual(splat0.attributes["out_layouts"], [layout])
|
|
|
|
self.assertEmpty(splat1.attributes["in_layouts"])
|
|
self.assertSequenceEqual(splat1.attributes["out_layouts"], [layout])
|
|
|
|
self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
|
|
self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
|
|
|
|
def test_infer_layout_from_consumer_for_non_splat_constant(self):
|
|
shape = (16, 8)
|
|
elt_type = ir.BF16Type.get()
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
ty = ir.VectorType.get(shape, elt_type)
|
|
attr_list = [
|
|
ir.FloatAttr.get(elt_type, i) for i in range(shape[0] * shape[1])
|
|
]
|
|
c = arith.ConstantOp(ty, ir.DenseElementsAttr.get(attr_list, ty))
|
|
add = arith.AddFOp(c, c)
|
|
|
|
layout = mgpu.to_strided_fragmented_layout_attr(
|
|
mgpu.WGStridedFragLayout(shape=shape, vec_size=1)
|
|
)
|
|
add.attributes["in_layouts"] = ir.ArrayAttr.get([layout, layout])
|
|
|
|
mgpu.infer_layout(self.module)
|
|
|
|
self.assertEmpty(c.attributes["in_layouts"])
|
|
self.assertSequenceEqual(c.attributes["out_layouts"], [layout])
|
|
|
|
def test_infer_splat_layout_for_vector_splat(self):
|
|
add = splat = None
|
|
|
|
def body(lhs, rhs):
|
|
nonlocal add, splat
|
|
splat = vector.SplatOp(rhs.type, lhs)
|
|
add = arith.AddFOp(splat, rhs)
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
shape = (16, 8)
|
|
elt_type = ir.BF16Type.get()
|
|
ty = ir.VectorType.get(shape, elt_type)
|
|
func.FuncOp.from_py_func(elt_type, ty)(body)
|
|
|
|
# Not setting any layouts on the module should default in all ops having a
|
|
# splat fragmented layout.
|
|
mgpu.infer_layout(self.module)
|
|
|
|
layout = mgpu.to_splat_fragmented_layout_attr(
|
|
mgpu.WGSplatFragLayout(shape=shape)
|
|
)
|
|
|
|
self.assertEmpty(splat.attributes["in_layouts"])
|
|
self.assertSequenceEqual(splat.attributes["out_layouts"], [layout])
|
|
|
|
self.assertSequenceEqual(add.attributes["in_layouts"], [layout, layout])
|
|
self.assertSequenceEqual(add.attributes["out_layouts"], [layout])
|
|
|
|
@parameterized.parameters(
|
|
mgpu.WGSplatFragLayout(shape=(32, 4)),
|
|
mgpu.WGStridedFragLayout(shape=(32, 4), vec_size=1),
|
|
)
|
|
def test_pointwise_op_propagates_argument_layouts(self, layout):
|
|
add = None
|
|
|
|
def body(lhs, rhs):
|
|
nonlocal add
|
|
add = arith.AddFOp(lhs, rhs)
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
ty = ir.VectorType.get(layout.shape, ir.BF16Type.get())
|
|
func.FuncOp.from_py_func(ty, ty)(body)
|
|
|
|
[f] = self.module.body.operations
|
|
layout_attr = _layout_to_attr(layout)
|
|
f.attributes["in_layouts"] = ir.ArrayAttr.get([layout_attr, layout_attr])
|
|
|
|
mgpu.infer_layout(self.module)
|
|
|
|
self.assertSequenceEqual(
|
|
add.attributes["in_layouts"], [layout_attr, layout_attr]
|
|
)
|
|
self.assertSequenceEqual(add.attributes["out_layouts"], [layout_attr])
|
|
|
|
def test_infer_layout_traverses_ops_correctly(self):
|
|
shape = (16, 8)
|
|
elt_type = ir.BF16Type.get()
|
|
add = None
|
|
|
|
def body(a, b):
|
|
bool_type = ir.IntegerType.get_signless(1)
|
|
cst_true = arith.constant(bool_type, ir.IntegerAttr.get(bool_type, 1))
|
|
if_op = scf.IfOp(cst_true)
|
|
with ir.InsertionPoint(if_op.then_block):
|
|
nonlocal add
|
|
add = arith.AddFOp(a, b)
|
|
scf.yield_([])
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
ab_type = ir.VectorType.get(shape, elt_type)
|
|
func.FuncOp.from_py_func(ab_type, ab_type)(body)
|
|
|
|
mgpu.infer_layout(self.module)
|
|
|
|
self.assertIn("in_layouts", add.attributes)
|
|
self.assertIn("out_layouts", add.attributes)
|
|
|
|
def test_infer_layout_has_no_layout_for_non_vector_types(self):
|
|
shape = (32, 4)
|
|
elt_ty = ir.BF16Type.get()
|
|
|
|
vector_store = None
|
|
|
|
def body(ref, array):
|
|
nonlocal vector_store
|
|
zero_index = arith.constant(ir.IndexType.get(), 0)
|
|
vector_store = vector.store(array, ref, [zero_index, zero_index])
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
ref_ty = ir.MemRefType.get(shape, elt_ty)
|
|
array_ty = ir.VectorType.get(shape, elt_ty)
|
|
func.FuncOp.from_py_func(ref_ty, array_ty)(body)
|
|
|
|
mgpu.infer_layout(self.module)
|
|
|
|
self.assertIn("in_layouts", vector_store.attributes)
|
|
self.assertIn("out_layouts", vector_store.attributes)
|
|
|
|
# The vector store should have a layout for the input array, but not for the
|
|
# memref.
|
|
self.assertLen(vector_store.attributes["in_layouts"], 1)
|
|
self.assertEmpty(vector_store.attributes["out_layouts"])
|
|
|
|
def test_infer_layout_picks_strided_layout_over_splat_layout(self):
|
|
add = None
|
|
|
|
def body(lhs, rhs):
|
|
nonlocal add
|
|
add = arith.AddFOp(lhs, rhs)
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
shape = (32, 4)
|
|
elt_type = ir.BF16Type.get()
|
|
ty = ir.VectorType.get(shape, elt_type)
|
|
|
|
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
|
|
|
|
splat_layout = mgpu.to_splat_fragmented_layout_attr(
|
|
mgpu.WGSplatFragLayout(shape)
|
|
)
|
|
strided_layout = mgpu.to_strided_fragmented_layout_attr(
|
|
mgpu.WGStridedFragLayout(shape, vec_size=1)
|
|
)
|
|
|
|
f.attributes["in_layouts"] = ir.ArrayAttr.get(
|
|
[strided_layout, splat_layout]
|
|
)
|
|
|
|
mgpu.infer_layout(self.module)
|
|
|
|
self.assertSequenceEqual(
|
|
add.attributes["in_layouts"], [strided_layout, strided_layout]
|
|
)
|
|
self.assertSequenceEqual(add.attributes["out_layouts"], [strided_layout])
|
|
|
|
def test_infer_layout_preserves_splat_layouts_in_producers(self):
|
|
add0 = add1 = None
|
|
|
|
def body(lhs, rhs):
|
|
nonlocal add0, add1
|
|
add0 = arith.AddFOp(lhs, rhs)
|
|
add1 = arith.AddFOp(add0, add0)
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
shape = (32, 4)
|
|
elt_type = ir.BF16Type.get()
|
|
ty = ir.VectorType.get(shape, elt_type)
|
|
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
|
|
|
|
splat_layout = mgpu.to_splat_fragmented_layout_attr(
|
|
mgpu.WGSplatFragLayout(shape)
|
|
)
|
|
strided_layout = mgpu.to_strided_fragmented_layout_attr(
|
|
mgpu.WGStridedFragLayout(shape, vec_size=1)
|
|
)
|
|
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
|
|
add1.attributes["out_layouts"] = ir.ArrayAttr.get([strided_layout])
|
|
mgpu.infer_layout(self.module)
|
|
|
|
self.assertSequenceEqual(
|
|
add0.attributes["in_layouts"], [splat_layout, splat_layout]
|
|
)
|
|
self.assertSequenceEqual(
|
|
add1.attributes["in_layouts"], [strided_layout, strided_layout]
|
|
)
|
|
|
|
self.assertSequenceEqual(add0.attributes["out_layouts"], [splat_layout])
|
|
self.assertSequenceEqual(add1.attributes["out_layouts"], [strided_layout])
|
|
|
|
def test_infer_layout_propagates_func_layouts_to_ops(self):
|
|
add = None
|
|
|
|
def body(lhs, rhs):
|
|
nonlocal add
|
|
add = arith.AddFOp(lhs, rhs)
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
shape = (32, 4)
|
|
ty = ir.VectorType.get(shape, ir.BF16Type.get())
|
|
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
|
|
|
|
splat_layout = mgpu.to_splat_fragmented_layout_attr(
|
|
mgpu.WGSplatFragLayout(shape)
|
|
)
|
|
f.attributes["in_layouts"] = ir.ArrayAttr.get([splat_layout, splat_layout])
|
|
mgpu.infer_layout(self.module)
|
|
|
|
self.assertSequenceEqual(
|
|
add.attributes["in_layouts"], [splat_layout, splat_layout])
|
|
self.assertSequenceEqual(add.attributes["out_layouts"], [splat_layout])
|
|
|
|
def test_infer_layout_does_not_assign_default_layouts_to_func(self):
|
|
|
|
def body(lhs, rhs):
|
|
arith.AddFOp(lhs, rhs)
|
|
|
|
with ir.InsertionPoint(self.module.body):
|
|
shape = (32, 4)
|
|
ty = ir.VectorType.get(shape, ir.BF16Type.get())
|
|
f = func.FuncOp.from_py_func(ty, ty)(body).func_op
|
|
|
|
mgpu.infer_layout(self.module)
|
|
self.assertNotIn("in_layouts", f.attributes)
|
|
self.assertNotIn("out_layouts", f.attributes)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parameterized.absltest.main(testLoader=jtu.JaxTestLoader())
|