[MOSAIC:GPU] Add async_load, async_store, and supporting attributes to the MLIR Mosaic GPU Dialect.

PiperOrigin-RevId: 694643777
This commit is contained in:
Dimitar (Mitko) Asenov 2024-11-08 14:32:59 -08:00 committed by jax authors
parent 7404e0d29d
commit d833066a1f
7 changed files with 686 additions and 28 deletions

View File

@ -29,7 +29,9 @@ td_library(
srcs = ["mosaic_gpu.td"],
includes = ["."],
deps = [
"@llvm-project//mlir:BasicPtxBuilderIntTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
],
)
@ -109,6 +111,7 @@ cc_library(
hdrs = ["mosaic_gpu.h"],
deps = [
":mosaic_gpu_inc_gen",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
@ -116,9 +119,11 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MemRefUtils",
"@llvm-project//mlir:SCFUtils",
"@llvm-project//mlir:Support",
"@tsl//tsl/platform:statusor",
@ -152,12 +157,19 @@ cc_test(
gentbl_filegroup(
name = "mosaic_gpu_python_gen_raw",
tbl_outs = [
(
[
"-gen-python-enum-bindings",
"-bind-dialect=mosaic_gpu",
],
"_mosaic_gpu_gen_enums_raw.py",
),
(
[
"-gen-python-op-bindings",
"-bind-dialect=mosaic_gpu",
],
"_mosaic_gpu_gen_raw.py",
"_mosaic_gpu_gen_ops_raw.py",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
@ -169,10 +181,19 @@ gentbl_filegroup(
)
genrule(
name = "mosaic_gpu_python_gen",
srcs = ["_mosaic_gpu_gen_raw.py"],
outs = ["_mosaic_gpu_gen.py"],
cmd = "cat $(location _mosaic_gpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
name = "mosaic_gpu_python_gen_enums",
srcs = ["_mosaic_gpu_gen_enums_raw.py"],
outs = ["_mosaic_gpu_gen_enums.py"],
cmd = """
cat $(location _mosaic_gpu_gen_enums_raw.py) | \
sed -e 's/^from \\.\\.ir/from jaxlib\\.mlir\\.ir/g; s/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@""",
)
# Post-processes the raw generated op bindings: the sed expression rewrites
# line-leading `from .` imports to `from jaxlib.mlir.dialects.` and the result
# is written to `_mosaic_gpu_gen_ops.py`.
genrule(
name = "mosaic_gpu_python_gen_ops",
srcs = ["_mosaic_gpu_gen_ops_raw.py"],
outs = ["_mosaic_gpu_gen_ops.py"],
cmd = "cat $(location _mosaic_gpu_gen_ops_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
)
DIALECT_CAPI_SOURCES = [

View File

@ -18,18 +18,17 @@ limitations under the License.
#include <cstdint>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
@ -44,6 +43,12 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mlir/include/mlir/IR/Diagnostics.h"
#include "tsl/platform/statusor.h"
// Generated definitions.
@ -232,11 +237,89 @@ void DeclareRuntimeFunctions(mlir::OpBuilder& builder) {
.setVisibility(mlir::func::FuncOp::Visibility::Private);
}
// Reports whether `type` is treated as contiguous. A memref qualifies when its
// layout is the identity, or when it has a static shape with more than zero
// elements that mlir::memref::isStaticShapeAndContiguousRowMajor accepts.
bool IsContiguous(mlir::MemRefType type) {
  if (type.getLayout().isIdentity()) {
    return true;
  }
  // The element count is only queried once the shape is known to be static.
  if (!type.hasStaticShape() || type.getNumElements() <= 0) {
    return false;
  }
  return mlir::memref::isStaticShapeAndContiguousRowMajor(type);
}
namespace {
// Runs the checks shared by the `async_load` and `async_store` verifiers.
//
// `gmem_type`/`smem_type` are the two memref types being related and
// `gmem_name`/`smem_name` are the operand names used in diagnostics.
// `slice_lengths` is the op's slice-length list and `num_indices` the number
// of index operands. The checks run in a fixed order and the first failing one
// emits an error at `loc`; success is returned when all of them pass.
llvm::LogicalResult VerifyCommonLoadStoreOp(
    mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name,
    mlir::MemRefType smem_type, absl::string_view smem_name,
    mlir::ArrayRef<int64_t> slice_lengths, int num_indices) {
  // Formats the message with llvm::formatv and emits it as an error at `loc`.
  auto fail = [loc](auto... args) {
    return emitError(loc, llvm::formatv(args...));
  };

  if (!IsContiguous(smem_type)) {
    return fail("The `{0}` memref must be contiguous.", smem_name);
  }

  const bool same_element_type =
      gmem_type.getElementType() == smem_type.getElementType();
  if (!same_element_type) {
    return fail(
        "The `source` and `destination` memrefs must have the same element "
        "type.");
  }

  // -1 is the only negative value allowed in `slice_lengths`.
  for (int64_t length : slice_lengths) {
    if (length < -1) {
      return fail(
          "The `slice_lengths` attribute must not contain values less than "
          "-1.");
    }
  }

  // Each -1 entry counts as one collapsed dimension in the rank comparison.
  const auto gmem_rank = gmem_type.getRank();
  const auto num_collapsed = absl::c_count(slice_lengths, -1);
  if (gmem_rank != smem_type.getRank() + num_collapsed) {
    return fail(
        "The rank of the `{0}` must be equal to the rank of the "
        "`{1}` plus the number of collapsed dimensions as indicated "
        "by -1 values in `slice_lengths`.",
        gmem_name, smem_name);
  }

  if (num_indices != gmem_rank) {
    return fail("The size of `indices` must be equal to the rank of `{0}`.",
                gmem_name);
  }

  if (slice_lengths.size() != gmem_rank) {
    return fail(
        "The size of `slice_lengths` must be equal to the rank of `{0}`.",
        gmem_name);
  }

  return llvm::success();
}
} // namespace
// Verifies an `async_load` op. The shared load/store checks run first, with
// `source` passed as the `gmem_type` operand and `destination` as the
// `smem_type` operand; afterwards the `collective` attribute is rejected if
// any two of its entries are equal.
llvm::LogicalResult AsyncLoadOp::verify() {
  auto common = VerifyCommonLoadStoreOp(
      getLoc(), getSource().getType(), "source", getDestination().getType(),
      "destination", getSliceLengths(), getIndices().size());
  if (failed(common)) {
    return common;
  }

  // Compare every unordered pair of entries exactly once.
  auto collective = getCollective();
  const size_t num_entries = collective.size();
  for (size_t lhs = 0; lhs < num_entries; ++lhs) {
    for (size_t rhs = lhs + 1; rhs < num_entries; ++rhs) {
      if (collective[lhs] == collective[rhs]) {
        return emitError(
            "The `collective` attribute must not contain duplicate "
            "dimensions.");
      }
    }
  }

  return llvm::success();
}
// Verifies an `async_store` op by running the shared load/store checks. Here
// `destination` is passed as the `gmem_type` operand and `source` as the
// `smem_type` operand.
llvm::LogicalResult AsyncStoreOp::verify() {
  auto gmem_type = getDestination().getType();
  auto smem_type = getSource().getType();
  return VerifyCommonLoadStoreOp(getLoc(), gmem_type, "destination", smem_type,
                                 "source", getSliceLengths(),
                                 getIndices().size());
}
void MosaicGPUDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc"
>();
addOperations<
#define GET_OP_LIST
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_ops.cc.inc"

View File

@ -19,14 +19,17 @@ limitations under the License.
#include <cstdint>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
// Generated definitions.
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep
@ -43,6 +46,10 @@ namespace mosaic_gpu {
using Memref = ::mlir::TypedValue<::mlir::MemRefType>;
using Pointer = ::mlir::TypedValue<::mlir::LLVM::LLVMPointerType>;
// Side-effect resource type declared by this dialect. Its reported name is
// "<GlobalMemory>".
struct GlobalMemory : mlir::SideEffects::Resource::Base<GlobalMemory> {
  llvm::StringRef getName() final {
    return "<GlobalMemory>";
  }
};
constexpr absl::string_view kRuntimeTmaDescriptorInitializerName =
"mosaic_gpu_init_tma_desc";
constexpr absl::string_view kRuntimeMemcpyAsyncH2DName =

View File

@ -16,6 +16,8 @@ limitations under the License.
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
@ -28,6 +30,7 @@ def MosaicGPU_Dialect : Dialect {
let name = "mosaic_gpu";
let cppNamespace = "::mosaic_gpu";
let useDefaultTypePrinterParser = 1;
let useDefaultAttributePrinterParser = 1;
}
class MosaicGPU_Type<string name, string mnemonic_, list<Trait> traits = []>
@ -35,6 +38,11 @@ class MosaicGPU_Type<string name, string mnemonic_, list<Trait> traits = []>
let mnemonic = mnemonic_;
}
// Base class for attributes of the Mosaic GPU dialect. `name` is passed on to
// AttrDef and `mnemonic_` becomes the attribute's mnemonic.
class MosaicGPU_Attr<string name, string mnemonic_>
: AttrDef<MosaicGPU_Dialect, name> {
let mnemonic = mnemonic_;
}
def MosaicGPU_Barrier : MosaicGPU_Type<"Barrier", "barrier", [MemRefElementTypeInterface]> {
let summary = "barrier";
let description = "A barrier to use for synchronizing threads";
@ -83,4 +91,181 @@ def MosaicGPU_FragmentedLayoutAttr : EnumAttr<
let assemblyFormat = "`<` $value `>`";
}
// Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td
// but it was not possible to reuse that definition. Including that file
// pulls in ops definitions that we don't want and they fail to compile.
def MosaicGPU_Dimension : I32EnumAttr<"Dimension",
"a dimension, either 'x', 'y', or 'z'",
[
I32EnumAttrCase<"x", 0>,
I32EnumAttrCase<"y", 1>,
I32EnumAttrCase<"z", 2>
]>{
let cppNamespace = "::mosaic_gpu";
let genSpecializedAttr = 0;
}
// Dialect attribute wrapping the `Dimension` enum, with mnemonic `dim`. The
// assembly format prints the enum value between angle brackets.
def MosaicGPU_DimensionAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_Dimension, "dim"> {
let assemblyFormat = "`<` $value `>`";
}
// I32 enum of the swizzling modes. Each case pairs a C++ symbol with an integer
// value and a string form ("none", "32", "64", "128"). No specialized attribute
// is generated for the enum itself (`genSpecializedAttr = 0`).
def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode",
"What swizzling to use for a memory access.",
[
I32EnumAttrCase<"kNoSwizzle", 0, "none">,
I32EnumAttrCase<"k32ByteSwizzle", 1, "32">,
I32EnumAttrCase<"k64ByteSwizzle", 2, "64">,
I32EnumAttrCase<"k128ByteSwizzle", 3, "128">
]>{
let cppNamespace = "::mosaic_gpu";
let genSpecializedAttr = 0;
}
// Dialect attribute wrapping the `SwizzlingMode` enum, with mnemonic `swizzle`.
def MosaicGPU_SwizzlingModeAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_SwizzlingMode, "swizzle">;
def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
let parameters = (ins Variadic<I64>:$tiling);
let summary = "Tiles a suffix of memref dimensions.";
let description = [{
For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
the shape of the result will be (5, 2, 4, 64, 32). The shape always ends
with the tile shape, and the size of tiled dimensions is divided by the tile
size. This is especially useful for swizzled WGMMA, which expects tiled
layouts in shared memory.
Each tiled dimension must have a size that is either smaller than the
corresponding tile size or a multiple of the tile size.
}];
let assemblyFormat = "`<` $tiling `>`";
}
// Attribute with mnemonic `transpose` that carries a list of I64 values named
// `permutation`, printed between angle brackets.
def TransposeTransformAttr : MosaicGPU_Attr<"TransposeTransform", "transpose"> {
let parameters = (ins Variadic<I64>:$permutation);
let summary = "Specifies how to transpose a memref.";
let assemblyFormat = "`<` $permutation `>`";
}
// Side-effect resource bound to the C++ class `::mosaic_gpu::GlobalMemory`; it
// is the resource named in the MemoryEffects of the async load/store ops.
def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">;
def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
[AttrSizedOperandSegments, MemoryEffects<[MemRead<GlobalMemory>]>]> {
let summary = "Schedules an async load of a MemRef from GMEM to SMEM";
let description = [{
Schedules an async copy of the contents of the `source` MemRef in GMEM to
the `destination` MemRef in SMEM. The `destination` MemRef in SMEM must be
contiguous.
If `arrive` is true, the `arrive.expect-tx(expect_count)` operation will be
executed on the provided `barrier` before the copy is scheduled. Upon
completion of the copy, the `complete-tx(complete-count)` operation will
always be executed on the provided `barrier`.
The `indices` and `slice_lengths` inputs define what slice of the GMEM
`source` corresponds to the SMEM `destination`. Both `indices` and
`slice_lengths` must have a length equal to the rank of the `source`. The
values in `indices` are the starting indices of each dimension and the
values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths`
indicates that the slice length is 1 and that the corresponding dimension
should be collapsed and does not appear in the `destination` MemRef.
Additional `transforms` may be provided to control how the `source` data is
mapped to the `destination`. The transformations will be composed in the
order they are provided. The `swizzle` attribute controls what swizzling
is applied to the data after it is transformed, before it is finally written
to SMEM. The transformed data is written in row-major order to the
contiguous SMEM `destination`. The untransformed `source` data does not need
to be contiguous, except for the last dimension, which needs to be
contiguous and the minor-most dimension.
The `collective` attribute can be provided to use TMA multicast to more
efficiently load the GMEM data in cases where multiple thread blocks are
grouped together in a cluster and need to load the same data. Each block in
a cluster will first load a slice from GMEM to SMEM and then the slices will
be multicast to all other blocks in the cluster. In this way TMA multicast
guarantees L2 cache hits. The `collective` attribute is the list of
cluster dimensions along which to partition the input data loads.
The `predicate` input should be set to `true` by a single thread in the
warpgroup so that it schedules the load operation. All other threads in the
warpgroup should set the `predicate` to `false`.
}];
let arguments = (ins
MemRefOf<[AnyType]>:$source,
MemRefOf<[AnyType]>:$destination,
MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier,
Variadic<I32>:$indices,
PtxPredicate:$predicate,
// Attributes
DenseI64ArrayAttr:$slice_lengths,
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::kNoSwizzle">:$swizzle,
DefaultValuedAttr<BoolAttr, "true" >:$arrive,
TypedArrayAttrBase<MosaicGPU_DimensionAttr, "dimensions">:$collective
);
let assemblyFormat = [{
`source` `(` $source `:` type($source) `)`
`destination` `(` $destination `:` type($destination) `)`
`barrier` `(` $barrier `:` type($barrier) `)`
`indices` `(` $indices `)`
`predicate` `(` $predicate `)`
attr-dict
}];
let hasVerifier = 1;
}
def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
[AttrSizedOperandSegments, MemoryEffects<[MemWrite<GlobalMemory>]>]> {
let summary = "Schedules an async store of a MemRef from SMEM to GMEM";
let description = [{
Schedules an async store of the contents of the `source` MemRef in SMEM to
the `destination` MemRef in GMEM. The `source` MemRef in SMEM must be
contiguous.
The `indices` and `slice_lengths` inputs define what slice of the GMEM
`destination` corresponds to the SMEM `source`. Both `indices` and
`slice_lengths` must have a length equal to the rank of the `destination`.
The values in `indices` are the starting indices of each dimension and the
values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths`
indicates that this dimension is collapsed in the `source` and needs to be
expanded to a slice of size 1 in the `destination`.
Additional `transforms` may be provided to control how the `destination`
data in GMEM is mapped to the `source` data in SMEM. The transformations
will be composed in the order they are provided. The `swizzle` attribute
is the swizzling mode of the `source` data in SMEM. The `source` SMEM data
is contiguous and the transformed data is written to the `destination` GMEM
which does not need to be contiguous.
The `predicate` input should be set to `true` by a single thread in the
warpgroup so that it schedules the store operation. All other threads in the
warpgroup should set the `predicate` to `false`.
}];
let arguments = (ins
MemRefOf<[AnyType]>:$source,
MemRefOf<[AnyType]>:$destination,
Variadic<I32>:$indices,
PtxPredicate:$predicate,
// Attributes
DenseI64ArrayAttr:$slice_lengths,
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::kNoSwizzle">:$swizzle
);
let assemblyFormat = [{
`source` `(` $source `:` type($source) `)`
`destination` `(` $destination `:` type($destination) `)`
`indices` `(` $indices `)`
`predicate` `(` $predicate `)`
attr-dict
}];
let hasVerifier = 1;
}
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_

View File

@ -21,7 +21,8 @@ py_library(
name = "gpu_dialect",
srcs = [
"mosaic_gpu.py",
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py",
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_enums.py",
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_ops.py",
],
visibility = ["//visibility:public"],
deps = [

View File

@ -23,7 +23,8 @@ name. Otherwise, MLIR is unable to find the module during dialect search.
# pylint: disable=g-bad-import-order
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_ops import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_enums import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
try:

View File

@ -22,9 +22,9 @@ from jax._src import test_util as jtu
from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import nvvm
from jax._src.lib.mlir.dialects import scf
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import
@ -50,12 +50,15 @@ def walk_operations(op: ir.OpView, callback):
callback(op)
def find_if(module: ir.Module,
predicate: Callable[[ir.OpView], bool]) -> list[ir.OpView]:
def find_if(
module: ir.Module, predicate: Callable[[ir.OpView], bool]
) -> list[ir.OpView]:
result = []
def callback(op: ir.OpView):
if predicate(op):
result.append(op)
for op in module.body.operations:
walk_operations(op, callback)
return result
@ -81,16 +84,19 @@ class DialectTest(parameterized.TestCase):
def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1)
ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1
)
with self.assertRaisesRegex(
ir.MLIRError, "must be memref of barrier values"):
ir.MLIRError, "must be memref of barrier values"
):
self.module.operation.verify()
def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
arrival_count=0)
arrival_count=0,
)
with self.assertRaisesRegex(ir.MLIRError, "value is positive"):
self.module.operation.verify()
@ -98,10 +104,358 @@ class DialectTest(parameterized.TestCase):
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
arrival_count=1)
arrival_count=1,
)
self.assertTrue(self.module.operation.verify())
self.assertIsInstance(self.module.body.operations[0],
mgpu.InitializeBarrierOp)
self.assertIsInstance(
self.module.body.operations[0], mgpu.InitializeBarrierOp
)
def test_async_load_op_dest_must_be_contiguous(self):
# Builds an `async_load` whose destination memref uses a `strided<[16, 1]>`
# layout on a [4, 8] shape and expects module verification to fail with the
# "destination memref must be contiguous" error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get(
[4, 8],
ir.F32Type.get(),
layout=ir.Attribute.parse("strided<[16, 1]>"),
),
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_load",
)(
lambda source, destination, barrier, *indices: mgpu.async_load(
source,
destination,
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The `destination` memref must be contiguous",
):
self.module.operation.verify()
def test_async_load_op_source_and_dest_must_have_same_element_type(self):
# Builds an `async_load` from an f32 source into an f64 destination and
# expects verification to fail with the element-type mismatch error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([4, 8], ir.F64Type.get()),
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_load",
)(
lambda source, destination, barrier, *indices: mgpu.async_load(
source,
destination,
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"`source` and `destination` memrefs must have the same element",
):
self.module.operation.verify()
def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self):
# Passes `slice_lengths=[-2, 8]` to `async_load` and expects verification to
# fail because a value below -1 is present.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_load",
)(
lambda source, destination, barrier, *indices: mgpu.async_load(
source,
destination,
barrier,
indices,
slice_lengths=[-2, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The `slice_lengths` attribute must not contain values less than -1",
):
self.module.operation.verify()
def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self):
# Source has rank 3, destination has rank 1, and `slice_lengths` contains a
# single -1, so the source rank differs from destination rank plus the number
# of -1 entries; verification is expected to report the rank error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([1, 4, 8], ir.F32Type.get()),
ir.MemRefType.get([4], ir.F32Type.get()),
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_load",
)(
lambda source, destination, barrier, *indices: mgpu.async_load(
source,
destination,
barrier,
indices,
slice_lengths=[-1, 4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"`destination` plus the number of collapsed dimensions as indicated",
):
self.module.operation.verify()
def test_async_load_op_indices_size_must_match_source_rank(self):
# Supplies a single index for a rank-2 source and expects verification to
# fail with the `indices` size error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
ir.IntegerType.get_signless(32),
name="async_load",
)(
lambda source, destination, barrier, *indices: mgpu.async_load(
source,
destination,
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The size of `indices` must be equal to the rank of `source`",
):
self.module.operation.verify()
def test_async_load_op_slice_lengths_size_must_match_source_rank(self):
# Supplies two slice lengths for a rank-1 source and expects verification to
# fail with the `slice_lengths` size error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4], ir.F32Type.get()),
ir.MemRefType.get([4], ir.F32Type.get()),
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
ir.IntegerType.get_signless(32),
name="async_load",
)(
lambda source, destination, barrier, *indices: mgpu.async_load(
source,
destination,
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The size of `slice_lengths` must be equal to the rank of `source`",
):
self.module.operation.verify()
def test_async_load_op_slice_collective_must_be_unique(self):
# Passes a `collective` list holding the `x` dimension attribute twice and
# expects verification to fail with the duplicate-dimensions error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4], ir.F32Type.get()),
ir.MemRefType.get([4], ir.F32Type.get()),
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
ir.IntegerType.get_signless(32),
name="async_load",
)(
lambda source, destination, barrier, *indices: mgpu.async_load(
source,
destination,
barrier,
indices,
slice_lengths=[4],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([
ir.Attribute.parse(
f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
),
ir.Attribute.parse(
f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
),
]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The `collective` attribute must not contain duplicate dimensions",
):
self.module.operation.verify()
def test_async_store_op_source_must_be_contiguous(self):
# Builds an `async_store` whose source memref uses a `strided<[16, 1]>` layout
# on a [4, 8] shape and expects verification to fail with the "source memref
# must be contiguous" error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get(
[4, 8],
ir.F32Type.get(),
layout=ir.Attribute.parse("strided<[16, 1]>"),
),
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_store",
)(
lambda source, destination, *indices: mgpu.async_store(
source,
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The `source` memref must be contiguous",
):
self.module.operation.verify()
def test_async_store_op_source_and_dest_must_have_same_element_type(self):
# Builds an `async_store` from an f32 source into an f64 destination and
# expects verification to fail with the element-type mismatch error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([4, 8], ir.F64Type.get()),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_store",
)(
lambda source, destination, *indices: mgpu.async_store(
source,
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"`source` and `destination` memrefs must have the same element",
):
self.module.operation.verify()
def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self):
# Passes `slice_lengths=[-2, 8]` to `async_store` and expects verification to
# fail because a value below -1 is present.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_store",
)(
lambda source, destination, *indices: mgpu.async_store(
source,
destination,
indices,
slice_lengths=[-2, 8],
transforms=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The `slice_lengths` attribute must not contain values less than -1",
):
self.module.operation.verify()
def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self):
# Source has rank 1, destination has rank 3, and `slice_lengths` contains a
# single -1, so the destination rank differs from source rank plus the number
# of -1 entries; verification is expected to report the rank error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4], ir.F32Type.get()),
ir.MemRefType.get([1, 4, 8], ir.F32Type.get()),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
ir.IntegerType.get_signless(32),
name="async_store",
)(
lambda source, destination, *indices: mgpu.async_store(
source,
destination,
indices,
slice_lengths=[-1, 4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"`source` plus the number of collapsed dimensions as indicated",
):
self.module.operation.verify()
def test_async_store_op_indices_size_must_match_destination_rank(self):
# Supplies a single index for a rank-2 destination and expects verification
# to fail with the `indices` size error.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.MemRefType.get([4, 8], ir.F32Type.get()),
ir.IntegerType.get_signless(32),
name="async_store",
)(
lambda source, destination, *indices: mgpu.async_store(
source,
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The size of `indices` must be equal to the rank of `destination`",
):
self.module.operation.verify()
def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
# Supplies two slice lengths for rank-1 memrefs and expects verification to
# fail with the `slice_lengths` size error.
# NOTE(review): the method name says "source_rank", but the expected message
# refers to the rank of `destination`.
with ir.InsertionPoint(self.module.body):
func.FuncOp.from_py_func(
ir.MemRefType.get([4], ir.F32Type.get()),
ir.MemRefType.get([4], ir.F32Type.get()),
ir.IntegerType.get_signless(32),
name="async_store",
)(
lambda source, destination, *indices: mgpu.async_store(
source,
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
with self.assertRaisesRegex(
ir.MLIRError,
"The size of `slice_lengths` must be equal to the rank of"
" `destination`",
):
self.module.operation.verify()
class DialectLoweringTest(DialectTest):
@ -110,11 +464,13 @@ class DialectLoweringTest(DialectTest):
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
arrival_count=1)
arrival_count=1,
)
lower_mgpu_dialect(self.module)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, self.module.body.operations)))
list(filter(is_mosaic_gpu_op, self.module.body.operations))
)
def test_lowering_traverses_regions_correctly(self):
with ir.InsertionPoint(self.module.body):
@ -124,12 +480,14 @@ class DialectLoweringTest(DialectTest):
with ir.InsertionPoint(if_op.then_block):
mgpu.initialize_barrier(
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
arrival_count=1)
arrival_count=1,
)
scf.yield_([])
lower_mgpu_dialect(self.module)
self.assertEmpty(
list(filter(is_mosaic_gpu_op, if_op.then_block.operations)))
list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
)
def test_initialize_barrier_op_lowering_rule(self):
shape = (3, 4)
@ -139,12 +497,14 @@ class DialectLoweringTest(DialectTest):
with ir.InsertionPoint(self.module.body):
mgpu.initialize_barrier(
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
arrival_count=arrival_count)
arrival_count=arrival_count,
)
lower_mgpu_dialect(self.module)
all_mbarrier_init_shared_ops = find_if(
self.module,
lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME)
lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME,
)
# One nvvm.mbarrier_init_shared is issued per barrier.
self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements)