mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[MOSAIC:GPU] Add async_load
, async_store
, and supporting attributes to the MLIR Mosaic GPU Dialect.
PiperOrigin-RevId: 694643777
This commit is contained in:
parent
7404e0d29d
commit
d833066a1f
@ -29,7 +29,9 @@ td_library(
|
||||
srcs = ["mosaic_gpu.td"],
|
||||
includes = ["."],
|
||||
deps = [
|
||||
"@llvm-project//mlir:BasicPtxBuilderIntTdFiles",
|
||||
"@llvm-project//mlir:BuiltinDialectTdFiles",
|
||||
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
|
||||
"@llvm-project//mlir:OpBaseTdFiles",
|
||||
],
|
||||
)
|
||||
@ -109,6 +111,7 @@ cc_library(
|
||||
hdrs = ["mosaic_gpu.h"],
|
||||
deps = [
|
||||
":mosaic_gpu_inc_gen",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -116,9 +119,11 @@ cc_library(
|
||||
"@llvm-project//mlir:ArithDialect",
|
||||
"@llvm-project//mlir:FuncDialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:LLVMCommonConversion",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
"@llvm-project//mlir:MemRefUtils",
|
||||
"@llvm-project//mlir:SCFUtils",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@tsl//tsl/platform:statusor",
|
||||
@ -152,12 +157,19 @@ cc_test(
|
||||
gentbl_filegroup(
|
||||
name = "mosaic_gpu_python_gen_raw",
|
||||
tbl_outs = [
|
||||
(
|
||||
[
|
||||
"-gen-python-enum-bindings",
|
||||
"-bind-dialect=mosaic_gpu",
|
||||
],
|
||||
"_mosaic_gpu_gen_enums_raw.py",
|
||||
),
|
||||
(
|
||||
[
|
||||
"-gen-python-op-bindings",
|
||||
"-bind-dialect=mosaic_gpu",
|
||||
],
|
||||
"_mosaic_gpu_gen_raw.py",
|
||||
"_mosaic_gpu_gen_ops_raw.py",
|
||||
),
|
||||
],
|
||||
tblgen = "@llvm-project//mlir:mlir-tblgen",
|
||||
@ -169,10 +181,19 @@ gentbl_filegroup(
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "mosaic_gpu_python_gen",
|
||||
srcs = ["_mosaic_gpu_gen_raw.py"],
|
||||
outs = ["_mosaic_gpu_gen.py"],
|
||||
cmd = "cat $(location _mosaic_gpu_gen_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
|
||||
name = "mosaic_gpu_python_gen_enums",
|
||||
srcs = ["_mosaic_gpu_gen_enums_raw.py"],
|
||||
outs = ["_mosaic_gpu_gen_enums.py"],
|
||||
cmd = """
|
||||
cat $(location _mosaic_gpu_gen_enums_raw.py) | \
|
||||
sed -e 's/^from \\.\\.ir/from jaxlib\\.mlir\\.ir/g; s/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@""",
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "mosaic_gpu_python_gen_ops",
|
||||
srcs = ["_mosaic_gpu_gen_ops_raw.py"],
|
||||
outs = ["_mosaic_gpu_gen_ops.py"],
|
||||
cmd = "cat $(location _mosaic_gpu_gen_ops_raw.py) | sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' > $@",
|
||||
)
|
||||
|
||||
DIALECT_CAPI_SOURCES = [
|
||||
|
@ -18,18 +18,17 @@ limitations under the License.
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
@ -44,6 +43,12 @@ limitations under the License.
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mlir/include/mlir/IR/Diagnostics.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
|
||||
// Generated definitions.
|
||||
@ -232,11 +237,89 @@ void DeclareRuntimeFunctions(mlir::OpBuilder& builder) {
|
||||
.setVisibility(mlir::func::FuncOp::Visibility::Private);
|
||||
}
|
||||
|
||||
bool IsContiguous(mlir::MemRefType type) {
|
||||
return type.getLayout().isIdentity() ||
|
||||
(type.hasStaticShape() && type.getNumElements() > 0 &&
|
||||
mlir::memref::isStaticShapeAndContiguousRowMajor(type));
|
||||
}
|
||||
|
||||
namespace {
|
||||
llvm::LogicalResult VerifyCommonLoadStoreOp(
|
||||
mlir::Location loc, mlir::MemRefType gmem_type, absl::string_view gmem_name,
|
||||
mlir::MemRefType smem_type, absl::string_view smem_name,
|
||||
mlir::ArrayRef<int64_t> slice_lengths, int num_indices) {
|
||||
auto error = [loc](auto... params) {
|
||||
return emitError(loc, llvm::formatv(params...));
|
||||
};
|
||||
|
||||
if (!IsContiguous(smem_type)) {
|
||||
return error("The `{0}` memref must be contiguous.", smem_name);
|
||||
}
|
||||
if (gmem_type.getElementType() != smem_type.getElementType()) {
|
||||
return error(
|
||||
"The `source` and `destination` memrefs must have the same element "
|
||||
"type.");
|
||||
}
|
||||
if (absl::c_any_of(slice_lengths, [](int64_t s) { return s < -1; })) {
|
||||
return error(
|
||||
"The `slice_lengths` attribute must not contain values less than -1.");
|
||||
}
|
||||
if (gmem_type.getRank() !=
|
||||
smem_type.getRank() + absl::c_count(slice_lengths, -1)) {
|
||||
return error(
|
||||
"The rank of the `{0}` must be equal to the rank of the "
|
||||
"`{1}` plus the number of collapsed dimensions as indicated "
|
||||
"by -1 values in `slice_lengths`.",
|
||||
gmem_name, smem_name);
|
||||
}
|
||||
if (num_indices != gmem_type.getRank()) {
|
||||
return error("The size of `indices` must be equal to the rank of `{0}`.",
|
||||
gmem_name);
|
||||
}
|
||||
if (slice_lengths.size() != gmem_type.getRank()) {
|
||||
return error(
|
||||
"The size of `slice_lengths` must be equal to the rank of `{0}`.",
|
||||
gmem_name);
|
||||
}
|
||||
return llvm::success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
llvm::LogicalResult AsyncLoadOp::verify() {
|
||||
auto r = VerifyCommonLoadStoreOp(getLoc(), getSource().getType(), "source",
|
||||
getDestination().getType(), "destination",
|
||||
getSliceLengths(), getIndices().size());
|
||||
if (failed(r)) {
|
||||
return r;
|
||||
}
|
||||
|
||||
for (int i = 0; i < getCollective().size(); ++i) {
|
||||
for (int k = i + 1; k < getCollective().size(); ++k)
|
||||
if (getCollective()[i] == getCollective()[k]) {
|
||||
return emitError(
|
||||
"The `collective` attribute must not contain duplicate "
|
||||
"dimensions.");
|
||||
}
|
||||
}
|
||||
|
||||
return llvm::success();
|
||||
}
|
||||
|
||||
llvm::LogicalResult AsyncStoreOp::verify() {
|
||||
return VerifyCommonLoadStoreOp(getLoc(), getDestination().getType(),
|
||||
"destination", getSource().getType(), "source",
|
||||
getSliceLengths(), getIndices().size());
|
||||
}
|
||||
|
||||
void MosaicGPUDialect::initialize() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_types.cc.inc"
|
||||
>();
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_attrdefs.cc.inc"
|
||||
>();
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_ops.cc.inc"
|
||||
|
@ -19,14 +19,17 @@ limitations under the License.
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
|
||||
// Generated definitions.
|
||||
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu_dialect.h.inc" // IWYU pragma: keep
|
||||
@ -43,6 +46,10 @@ namespace mosaic_gpu {
|
||||
using Memref = ::mlir::TypedValue<::mlir::MemRefType>;
|
||||
using Pointer = ::mlir::TypedValue<::mlir::LLVM::LLVMPointerType>;
|
||||
|
||||
struct GlobalMemory : public mlir::SideEffects::Resource::Base<GlobalMemory> {
|
||||
llvm::StringRef getName() final { return "<GlobalMemory>"; }
|
||||
};
|
||||
|
||||
constexpr absl::string_view kRuntimeTmaDescriptorInitializerName =
|
||||
"mosaic_gpu_init_tma_desc";
|
||||
constexpr absl::string_view kRuntimeMemcpyAsyncH2DName =
|
||||
|
@ -16,6 +16,8 @@ limitations under the License.
|
||||
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
|
||||
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
|
||||
|
||||
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/BuiltinTypeInterfaces.td"
|
||||
include "mlir/IR/CommonAttrConstraints.td"
|
||||
@ -28,6 +30,7 @@ def MosaicGPU_Dialect : Dialect {
|
||||
let name = "mosaic_gpu";
|
||||
let cppNamespace = "::mosaic_gpu";
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
}
|
||||
|
||||
class MosaicGPU_Type<string name, string mnemonic_, list<Trait> traits = []>
|
||||
@ -35,6 +38,11 @@ class MosaicGPU_Type<string name, string mnemonic_, list<Trait> traits = []>
|
||||
let mnemonic = mnemonic_;
|
||||
}
|
||||
|
||||
class MosaicGPU_Attr<string name, string mnemonic_>
|
||||
: AttrDef<MosaicGPU_Dialect, name> {
|
||||
let mnemonic = mnemonic_;
|
||||
}
|
||||
|
||||
def MosaicGPU_Barrier : MosaicGPU_Type<"Barrier", "barrier", [MemRefElementTypeInterface]> {
|
||||
let summary = "barrier";
|
||||
let description = "A barrier to use for synchronizing threads";
|
||||
@ -83,4 +91,181 @@ def MosaicGPU_FragmentedLayoutAttr : EnumAttr<
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
// Note: This duplicates the Dimension enum in mlir/Dialect/GPU/IR/GPUOps.td
|
||||
// but it was not possible to reuse that definition. Including that file
|
||||
// pulls in ops definitions that we don't want and they fail to compile.
|
||||
def MosaicGPU_Dimension : I32EnumAttr<"Dimension",
|
||||
"a dimension, either 'x', 'y', or 'z'",
|
||||
[
|
||||
I32EnumAttrCase<"x", 0>,
|
||||
I32EnumAttrCase<"y", 1>,
|
||||
I32EnumAttrCase<"z", 2>
|
||||
]>{
|
||||
let cppNamespace = "::mosaic_gpu";
|
||||
let genSpecializedAttr = 0;
|
||||
}
|
||||
|
||||
def MosaicGPU_DimensionAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_Dimension, "dim"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode",
|
||||
"What swizzling to use for a memory access.",
|
||||
[
|
||||
I32EnumAttrCase<"kNoSwizzle", 0, "none">,
|
||||
I32EnumAttrCase<"k32ByteSwizzle", 1, "32">,
|
||||
I32EnumAttrCase<"k64ByteSwizzle", 2, "64">,
|
||||
I32EnumAttrCase<"k128ByteSwizzle", 3, "128">
|
||||
]>{
|
||||
let cppNamespace = "::mosaic_gpu";
|
||||
let genSpecializedAttr = 0;
|
||||
}
|
||||
|
||||
def MosaicGPU_SwizzlingModeAttr : EnumAttr<MosaicGPU_Dialect, MosaicGPU_SwizzlingMode, "swizzle">;
|
||||
|
||||
def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
|
||||
let parameters = (ins Variadic<I64>:$tiling);
|
||||
let summary = "Tiles a suffix of memref dimensions.";
|
||||
let description = [{
|
||||
For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
|
||||
the shape of the result will be (5, 2, 4, 64, 32). The shape always ends
|
||||
with the tile shape, and the size of tiled dimensions is divided by the tile
|
||||
size. This is especially useful for swizzled WGMMA, which expect tiled
|
||||
layouts in shared memory.
|
||||
|
||||
Each tiled dimension must have a size that is either smaller than the
|
||||
corresponding tile size or a multiple of the tile size.
|
||||
}];
|
||||
let assemblyFormat = "`<` $tiling `>`";
|
||||
}
|
||||
|
||||
def TransposeTransformAttr : MosaicGPU_Attr<"TransposeTransform", "transpose"> {
|
||||
let parameters = (ins Variadic<I64>:$permutation);
|
||||
let summary = "Specifies how to transpose a memref.";
|
||||
let assemblyFormat = "`<` $permutation `>`";
|
||||
}
|
||||
|
||||
def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">;
|
||||
|
||||
def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
|
||||
[AttrSizedOperandSegments, MemoryEffects<[MemRead<GlobalMemory>]>]> {
|
||||
let summary = "Schedules an async load of a MemRef from GMEM to SMEM";
|
||||
let description = [{
|
||||
Schedules an async copy of the contents of the `source` MemRef in GMEM to
|
||||
the `destination` MemRef in SMEM. The `destination` MemRef in SMEM must be
|
||||
contiguous.
|
||||
|
||||
If `arrive` is true, the `arrive.expect-tx(expect_count)` operation will be
|
||||
executed on the provided `barrier` before the copy is scheduled. Upon
|
||||
completion of the copy, the `complete-tx(complete-count)` operation will
|
||||
always be executed on the provided `barrier`.
|
||||
|
||||
The `indices` and `slice_lengths` inputs define what slice of the GMEM
|
||||
`source` corresponds to the SMEM `destination`. Both `indices` and
|
||||
`slice_lengths` must have a length equal to the rank of the `source`. The
|
||||
values in `indices` are the starting indices of each dimension and the
|
||||
values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths`
|
||||
indicates that the slice length is 1 and that the corresponding dimension
|
||||
should be collapsed and does not appear in the `destination` MemRef.
|
||||
|
||||
Additional `transforms` may be provided to control how the `source` data is
|
||||
mapped to the `destination`. The transformations will be composed in the
|
||||
order they are provided. The `swizzle` attribute controls what swizzling
|
||||
is applied to the data after it is transformed, before it is finally written
|
||||
to SMEM. The transformed data is written in row-major order to the
|
||||
contiguous SMEM `destination`. The untransformed `source` data does not need
|
||||
to be contiguous, except for the last dimension, which needs to be
|
||||
contiguous and the minor-most dimension.
|
||||
|
||||
The `collective` attribute can be provided to use TMA multicast to more
|
||||
efficiently load the GMEM data in cases where multiple thread blocks are
|
||||
grouped together in a cluster and need to load the same data. Each block in
|
||||
a cluster will first load a slice from GMEM to SMEM and then the slices will
|
||||
be multicast to all other blocks in the cluster. In this way TMA multicast
|
||||
guarnatees L2 cache hits. The `collective` attribute is the list of
|
||||
cluster dimensions along which to partition the input data loads.
|
||||
|
||||
The `predicate` input should be set to `true` by a single thread in the
|
||||
warpgroup so that it schedules the load operation. All other threads in the
|
||||
warpgroup should set the `predicate` to `false`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
MemRefOf<[AnyType]>:$source,
|
||||
MemRefOf<[AnyType]>:$destination,
|
||||
MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier,
|
||||
Variadic<I32>:$indices,
|
||||
PtxPredicate:$predicate,
|
||||
|
||||
// Attributes
|
||||
DenseI64ArrayAttr:$slice_lengths,
|
||||
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
|
||||
DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::kNoSwizzle">:$swizzle,
|
||||
DefaultValuedAttr<BoolAttr, "true" >:$arrive,
|
||||
TypedArrayAttrBase<MosaicGPU_DimensionAttr, "dimensions">:$collective
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`source` `(` $source `:` type($source) `)`
|
||||
`destination` `(` $destination `:` type($destination) `)`
|
||||
`barrier` `(` $barrier `:` type($barrier) `)`
|
||||
`indices` `(` $indices `)`
|
||||
`predicate` `(` $predicate `)`
|
||||
attr-dict
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
|
||||
[AttrSizedOperandSegments, MemoryEffects<[MemWrite<GlobalMemory>]>]> {
|
||||
let summary = "Schedules an async store of a MemRef from SMEM to GMEM";
|
||||
let description = [{
|
||||
Schedules an async store of the contents of the `source` MemRef in SMEM to
|
||||
the `destination` MemRef in GMEM. The `source` MemRef in SMEM must be
|
||||
contiguous.
|
||||
|
||||
The `indices` and `slice_lengths` inputs define what slice of the GMEM
|
||||
`destination` corresponds to the SMEM `source`. Both `indices` and
|
||||
`slice_lengths` must have a length equal to the rank of the `destination`.
|
||||
The values in `indices` are the starting indices of each dimension and the
|
||||
values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths`
|
||||
indicates that this dimension is collapsed in the `source` and needs to be
|
||||
expanded to a slice of size 1 in the `destination`.
|
||||
|
||||
Additional `transforms` may be provided to control how the `destination`
|
||||
data in GMEM is mapped to the `source` data in SMEM. The transformations
|
||||
will be composed in the order they are provided. The `swizzle` attribute
|
||||
is the swizzling mode of the `source` data in SMEM. The `source` SMEM data
|
||||
is contiguous and the transformed data is written to the `destination` GMEM
|
||||
which does not need to be contiguous.
|
||||
|
||||
The `predicate` input should be set to `true` by a single thread in the
|
||||
warpgroup so that it schedules the store operation. All other threads in the
|
||||
warpgroup should set the `predicate` to `false`.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
MemRefOf<[AnyType]>:$source,
|
||||
MemRefOf<[AnyType]>:$destination,
|
||||
Variadic<I32>:$indices,
|
||||
PtxPredicate:$predicate,
|
||||
|
||||
// Attributes
|
||||
DenseI64ArrayAttr:$slice_lengths,
|
||||
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
|
||||
DefaultValuedAttr<MosaicGPU_SwizzlingModeAttr, "SwizzlingMode::kNoSwizzle">:$swizzle
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`source` `(` $source `:` type($source) `)`
|
||||
`destination` `(` $destination `:` type($destination) `)`
|
||||
`indices` `(` $indices `)`
|
||||
`predicate` `(` $predicate `)`
|
||||
attr-dict
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
|
||||
|
@ -21,7 +21,8 @@ py_library(
|
||||
name = "gpu_dialect",
|
||||
srcs = [
|
||||
"mosaic_gpu.py",
|
||||
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen.py",
|
||||
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_enums.py",
|
||||
"//jaxlib/mosaic/dialect/gpu:_mosaic_gpu_gen_ops.py",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
|
@ -23,7 +23,8 @@ name. Otherwise, MLIR is unable to find the module during dialect search.
|
||||
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
|
||||
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_ops import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
|
||||
from jaxlib.mosaic.dialect.gpu._mosaic_gpu_gen_enums import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
|
||||
from jaxlib.mlir._mlir_libs._mosaic_gpu_ext import * # pylint: disable=wildcard-import # type: ignore[import-not-found]
|
||||
|
||||
try:
|
||||
|
@ -22,9 +22,9 @@ from jax._src import test_util as jtu
|
||||
from jax._src.interpreters import mlir as mlir_interpreter
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import arith
|
||||
from jax._src.lib.mlir.dialects import func
|
||||
from jax._src.lib.mlir.dialects import nvvm
|
||||
from jax._src.lib.mlir.dialects import scf
|
||||
|
||||
from jax.experimental.mosaic.gpu import dialect as mgpu # pylint: disable=g-importing-member
|
||||
from jax.experimental.mosaic.gpu import lower_mgpu_dialect # pylint: disable=g-importing-member,g-multiple-import
|
||||
|
||||
@ -50,12 +50,15 @@ def walk_operations(op: ir.OpView, callback):
|
||||
callback(op)
|
||||
|
||||
|
||||
def find_if(module: ir.Module,
|
||||
predicate: Callable[[ir.OpView], bool]) -> list[ir.OpView]:
|
||||
def find_if(
|
||||
module: ir.Module, predicate: Callable[[ir.OpView], bool]
|
||||
) -> list[ir.OpView]:
|
||||
result = []
|
||||
|
||||
def callback(op: ir.OpView):
|
||||
if predicate(op):
|
||||
result.append(op)
|
||||
|
||||
for op in module.body.operations:
|
||||
walk_operations(op, callback)
|
||||
return result
|
||||
@ -81,16 +84,19 @@ class DialectTest(parameterized.TestCase):
|
||||
def test_initialize_barrier_op_result_memref_must_wrap_barriers(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1)
|
||||
ir.MemRefType.get((1, 2), ir.F32Type.get()), arrival_count=1
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError, "must be memref of barrier values"):
|
||||
ir.MLIRError, "must be memref of barrier values"
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_initialize_barrier_op_arrival_count_must_be_strictly_positive(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
arrival_count=0)
|
||||
arrival_count=0,
|
||||
)
|
||||
with self.assertRaisesRegex(ir.MLIRError, "value is positive"):
|
||||
self.module.operation.verify()
|
||||
|
||||
@ -98,10 +104,358 @@ class DialectTest(parameterized.TestCase):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
arrival_count=1)
|
||||
arrival_count=1,
|
||||
)
|
||||
self.assertTrue(self.module.operation.verify())
|
||||
self.assertIsInstance(self.module.body.operations[0],
|
||||
mgpu.InitializeBarrierOp)
|
||||
self.assertIsInstance(
|
||||
self.module.body.operations[0], mgpu.InitializeBarrierOp
|
||||
)
|
||||
|
||||
def test_async_load_op_dest_must_be_contiguous(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get(
|
||||
[4, 8],
|
||||
ir.F32Type.get(),
|
||||
layout=ir.Attribute.parse("strided<[16, 1]>"),
|
||||
),
|
||||
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The `destination` memref must be contiguous",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_load_op_source_and_dest_must_have_same_element_type(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4, 8], ir.F64Type.get()),
|
||||
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"`source` and `destination` memrefs must have the same element",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_load_op_slice_lengths_must_be_larger_than_minus_two(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[-2, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The `slice_lengths` attribute must not contain values less than -1",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_load_op_source_and_dest_ranks_must_match_with_collapse(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([1, 4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[-1, 4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"`destination` plus the number of collapsed dimensions as indicated",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_load_op_indices_size_must_match_source_rank(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The size of `indices` must be equal to the rank of `source`",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_load_op_slice_lengths_size_must_match_source_rank(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The size of `slice_lengths` must be equal to the rank of `source`",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_load_op_slice_collective_must_be_unique(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.MemRefType.get([], ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_load",
|
||||
)(
|
||||
lambda source, destination, barrier, *indices: mgpu.async_load(
|
||||
source,
|
||||
destination,
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([
|
||||
ir.Attribute.parse(
|
||||
f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
|
||||
),
|
||||
ir.Attribute.parse(
|
||||
f"#mosaic_gpu.dim<{mgpu.Dimension.x.name}>"
|
||||
),
|
||||
]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The `collective` attribute must not contain duplicate dimensions",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_store_op_source_must_be_contiguous(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get(
|
||||
[4, 8],
|
||||
ir.F32Type.get(),
|
||||
layout=ir.Attribute.parse("strided<[16, 1]>"),
|
||||
),
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The `source` memref must be contiguous",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_store_op_source_and_dest_must_have_same_element_type(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4, 8], ir.F64Type.get()),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"`source` and `destination` memrefs must have the same element",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_store_op_slice_lengths_must_be_larger_than_minus_two(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[-2, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The `slice_lengths` attribute must not contain values less than -1",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_store_op_source_and_dest_ranks_must_match_with_collapse(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.MemRefType.get([1, 4, 8], ir.F32Type.get()),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[-1, 4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"`source` plus the number of collapsed dimensions as indicated",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_store_op_indices_size_must_match_destination_rank(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4, 8], ir.F32Type.get()),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The size of `indices` must be equal to the rank of `destination`",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
def test_async_store_op_slice_lengths_size_must_match_source_rank(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
func.FuncOp.from_py_func(
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.MemRefType.get([4], ir.F32Type.get()),
|
||||
ir.IntegerType.get_signless(32),
|
||||
name="async_store",
|
||||
)(
|
||||
lambda source, destination, *indices: mgpu.async_store(
|
||||
source,
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ir.MLIRError,
|
||||
"The size of `slice_lengths` must be equal to the rank of"
|
||||
" `destination`",
|
||||
):
|
||||
self.module.operation.verify()
|
||||
|
||||
|
||||
class DialectLoweringTest(DialectTest):
|
||||
@ -110,11 +464,13 @@ class DialectLoweringTest(DialectTest):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
arrival_count=1)
|
||||
arrival_count=1,
|
||||
)
|
||||
lower_mgpu_dialect(self.module)
|
||||
|
||||
self.assertEmpty(
|
||||
list(filter(is_mosaic_gpu_op, self.module.body.operations)))
|
||||
list(filter(is_mosaic_gpu_op, self.module.body.operations))
|
||||
)
|
||||
|
||||
def test_lowering_traverses_regions_correctly(self):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
@ -124,12 +480,14 @@ class DialectLoweringTest(DialectTest):
|
||||
with ir.InsertionPoint(if_op.then_block):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get((1, 2), ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
arrival_count=1)
|
||||
arrival_count=1,
|
||||
)
|
||||
scf.yield_([])
|
||||
lower_mgpu_dialect(self.module)
|
||||
|
||||
self.assertEmpty(
|
||||
list(filter(is_mosaic_gpu_op, if_op.then_block.operations)))
|
||||
list(filter(is_mosaic_gpu_op, if_op.then_block.operations))
|
||||
)
|
||||
|
||||
def test_initialize_barrier_op_lowering_rule(self):
|
||||
shape = (3, 4)
|
||||
@ -139,12 +497,14 @@ class DialectLoweringTest(DialectTest):
|
||||
with ir.InsertionPoint(self.module.body):
|
||||
mgpu.initialize_barrier(
|
||||
ir.MemRefType.get(shape, ir.Type.parse("!mosaic_gpu.barrier")),
|
||||
arrival_count=arrival_count)
|
||||
arrival_count=arrival_count,
|
||||
)
|
||||
lower_mgpu_dialect(self.module)
|
||||
|
||||
all_mbarrier_init_shared_ops = find_if(
|
||||
self.module,
|
||||
lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME)
|
||||
lambda op: op.name == nvvm.MBarrierInitSharedOp.OPERATION_NAME,
|
||||
)
|
||||
|
||||
# One nvvm.mbarrier_init_shared is issued per barrier.
|
||||
self.assertLen(all_mbarrier_init_shared_ops, num_shape_elements)
|
||||
|
Loading…
x
Reference in New Issue
Block a user