// Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 12:26:07 +00:00).
// Source listing: 462 lines, 18 KiB, TableGen.
/* Copyright 2024 The JAX Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
|
|
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
|
|
|
|
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
|
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
|
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
|
include "mlir/IR/AttrTypeBase.td"
|
|
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
|
include "mlir/IR/BuiltinTypeInterfaces.td"
|
|
include "mlir/IR/CommonAttrConstraints.td"
|
|
include "mlir/IR/CommonTypeConstraints.td"
|
|
include "mlir/IR/DialectBase.td"
|
|
include "mlir/IR/EnumAttr.td"
|
|
include "mlir/IR/OpBase.td"
|
|
|
|
// The `mosaic_gpu` dialect. Generated C++ is placed in `::mosaic_gpu`, and
// both custom types and custom attributes rely on the ODS-generated default
// printer/parser hooks.
def MosaicGPU_Dialect : Dialect {
  let cppNamespace = "::mosaic_gpu";
  let name = "mosaic_gpu";

  let useDefaultAttributePrinterParser = 1;
  let useDefaultTypePrinterParser = 1;
}
|
|
|
|
// Common base for every type owned by the Mosaic GPU dialect.
//   name      - base name handed to TypeDef for the generated C++ class.
//   mnemonic_ - keyword identifying the type in textual IR.
//   traits    - optional traits/interfaces attached to the type.
class MosaicGPU_Type<string name, string mnemonic_, list<Trait> traits = []>
    : TypeDef<MosaicGPU_Dialect, name, traits> {
  let mnemonic = mnemonic_;
}
|
|
|
|
// Common base for every attribute owned by the Mosaic GPU dialect.
//   name      - base name handed to AttrDef for the generated C++ class.
//   mnemonic_ - keyword identifying the attribute in textual IR.
//   traits    - optional traits/interfaces attached to the attribute.
class MosaicGPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
    : AttrDef<MosaicGPU_Dialect, name, traits> {
  let mnemonic = mnemonic_;
}
|
|
|
|
// Barrier type. It implements MemRefElementTypeInterface so it can be the
// element type of memrefs, e.g. the result of `initialize_barrier`.
def MosaicGPU_Barrier
    : MosaicGPU_Type<"Barrier", "barrier", [MemRefElementTypeInterface]> {
  let description = "A barrier to use for synchronizing threads";
  let summary = "barrier";
}
|
|
|
|
// LLVM pointer type restricted to address space 3. It is used as the type of
// the shared-memory `base_pointer` of `initialize_barrier`.
// NOTE(review): address space 3 is presumably the NVPTX shared-memory space;
// confirm before reusing this for other targets.
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
|
|
|
|
// `initialize_barrier`: returns a memref of barriers whose base pointer is the
// given shared-memory pointer, each barrier expecting `arrival_count` threads.
def MosaicGPU_InitializeBarrierOp
    : Op<MosaicGPU_Dialect, "initialize_barrier", []> {
  let summary = "Initializes a memref of barriers";
  let description = [{
    Initializes a memref of barriers each meant to synchronize exactly
    `arrival_count` threads.

    The base pointer of the result memref corresponds to `base_pointer`, which
    must be a pointer to a shared memory location.
  }];

  // `arrival_count` is constrained to strictly positive 64-bit integers.
  let arguments = (ins LLVM_PointerShared:$base_pointer,
                       ConfinedAttr<I64Attr, [IntPositive]>:$arrival_count);
  let results = (outs MemRefOf<[MosaicGPU_Barrier]>:$barriers_ref);

  let assemblyFormat =
      "$base_pointer $arrival_count attr-dict `:` type($barriers_ref)";
}
|
|
|
|
// `arrive_expect_tx`: arrive.expect_tx on a single barrier held in a rank-0
// memref, with a non-negative `expect_tx` value.
def MosaicGPU_ArriveExpectTxOp
    : Op<MosaicGPU_Dialect, "arrive_expect_tx", []> {
  let summary = "Executes an arrive.expect_tx operation on the given barrier.";
  let description = [{
    A single thread in the warpgroup will execute an `arrive.expect_tx`
    operation on the provided barrier with the provided `expect_tx`.
  }];

  let arguments = (ins
    // Rank-0 memref: exactly one barrier.
    MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier,
    // 32-bit attribute, constrained to be >= 0.
    ConfinedAttr<I32Attr, [IntNonNegative]>:$expect_tx
  );

  let assemblyFormat =
      "`barrier` `(` $barrier `:` type($barrier) `)` $expect_tx attr-dict";
}
|
|
|
|
// `wait`: blocks on a single barrier (rank-0 memref) until the given i1
// parity is reached.
def MosaicGPU_WaitOp : Op<MosaicGPU_Dialect, "wait", []> {
  let summary = "Executes a wait operation on the given barrier.";
  let description = [{
    All threads in the warpgroup will block, waiting on the provided barrier
    until:
    - all pending threads have arrived on the barrier
    - all expected byte transfers have been completed
    - the barrier's parity matches the provided parity
  }];

  let arguments = (ins MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier,
                       I1:$parity);

  let assemblyFormat = "`barrier` `(` $barrier `:` type($barrier) `)` "
                       "`parity` `(` $parity `:` type($parity) `)` attr-dict";
}
|
|
|
|
// Strided fragmented layout: an nD array collapsed to 1D and sharded across
// the threads of a warpgroup, `vector_size` contiguous elements per thread.
// Defined through the dialect's MosaicGPU_Attr helper (mnemonic
// "WGStridedFragLayout") like the other attributes in this file.
def MosaicGPU_WGStridedFragLayout
    : MosaicGPU_Attr<"WGStridedFragLayout", "WGStridedFragLayout"> {
  let summary = "Annotates an array that can be collapsed to 1D and sharded across threads.";
  let description = [{
    This layout is typically used when working with pointwise operations, or
    other operations with trivial data dependency patterns.

    The layout holds the shape of the nD array it is meant to annotate, and a
    vector size representing the number of contiguous elements sharded to each
    thread.
  }];

  let parameters = (ins "::mlir::ArrayAttr":$shape, "int":$vector_size);
  let assemblyFormat = "`<` $shape `,` $vector_size `>`";
}
|
|
|
|
|
|
// Splat layout: every thread of the warpgroup holds a single copy of the
// value; only the splatted-to shape is recorded. Defined through the dialect's
// MosaicGPU_Attr helper (mnemonic "WGSplatFragLayout").
def MosaicGPU_WGSplatFragLayout
    : MosaicGPU_Attr<"WGSplatFragLayout", "WGSplatFragLayout"> {
  let summary = "Annotates an array that is the result of a splat.";
  let description = [{
    This layout is used to handle splat values. In this case, each thread in
    the warpgroup gets a single copy of the value.

    The layout holds the shape that the initial scalar is splatted to.
  }];

  let parameters = (ins "::mlir::ArrayAttr":$shape);
  let assemblyFormat = "`<` $shape `>`";
}
|
|
|
|
// Parameterless layout for 2D WGMMA fragments. Defined through the dialect's
// MosaicGPU_Attr helper (mnemonic "WGMMAFragLayout"); the empty assembly
// format means nothing is printed after the mnemonic.
def MosaicGPU_WGMMAFragLayout
    : MosaicGPU_Attr<"WGMMAFragLayout", "WGMMAFragLayout"> {
  let summary = "2D array that can be tiled by supported WGMMA shapes.";
  let description = [{
    This layout annotates arrays that are fragmented across all threads in a
    warpgroup that is executing a WGMMA operation. The shape of the array is
    (m, n) where:
    - m % 64 == 0
    - n % 8 == 0
  }];

  let assemblyFormat = "";
}
|
|
|
|
// Parameterless layout for 1D rows of WGMMA fragments. Defined through the
// dialect's MosaicGPU_Attr helper (mnemonic "WGMMARowFragLayout"); the empty
// assembly format means nothing is printed after the mnemonic.
def MosaicGPU_WGMMARowFragLayout
    : MosaicGPU_Attr<"WGMMARowFragLayout", "WGMMARowFragLayout"> {
  let summary = "1D array that is a row that can be tiled by supported WGMMA shapes.";
  let description = [{
    This layout is used to handle rows that are fragmented across all threads
    in a warpgroup that is executing a WGMMA operation. The length of the array
    must be divisible by 64.
  }];

  let assemblyFormat = "";
}
|
|
|
|
// Mirrors the `Dimension` enum of mlir/Dialect/GPU/IR/GPUOps.td. That
// definition cannot be reused directly: including GPUOps.td also pulls in op
// definitions that are unwanted here and fail to compile.
def MosaicGPU_Dimension : I32EnumAttr<
    "Dimension", "a dimension, either 'x', 'y', or 'z'",
    [I32EnumAttrCase<"x", 0>, I32EnumAttrCase<"y", 1>, I32EnumAttrCase<"z", 2>]> {
  let cppNamespace = "::mosaic_gpu";
}
|
|
|
|
// Swizzling mode for memory accesses. Each case's integer value equals its
// swizzle size in bytes (`kNoSwizzle` uses 16).
def MosaicGPU_SwizzlingMode : I32EnumAttr<
    "SwizzlingMode", "What swizzling to use for a memory access.", [
      I32EnumAttrCase<"kNoSwizzle", 16>,
      I32EnumAttrCase<"k32ByteSwizzle", 32>,
      I32EnumAttrCase<"k64ByteSwizzle", 64>,
      I32EnumAttrCase<"k128ByteSwizzle", 128>
    ]> {
  let cppNamespace = "::mosaic_gpu";
}
|
|
|
|
// Tiling transform (mnemonic `tile`) applied to the trailing dimensions of an
// SMEM memref; one of the transforms a LayoutAttr can carry.
def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
  let summary = "Specifies a transform that tiles suffix dimensions of a memref in SMEM.";
  let description = [{
    For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
    the shape of the result will be (5, 2, 4, 64, 32). The shape always ends
    with the tile shape, and the size of tiled dimensions is divided by the tile
    size. This is especially useful for swizzled WGMMA, which expect tiled
    layouts in shared memory.

    Each tiled dimension must have a size that is either smaller than the
    corresponding tile size or a multiple of the tile size.
  }];

  // NOTE(review): `Variadic<I64>` is an operand type constraint, not an
  // attribute/type parameter; confirm the generated accessor type is what is
  // intended (an ArrayRefParameter of integers may be the real intent).
  let parameters = (ins Variadic<I64>:$tiling);
  let assemblyFormat = "`<` $tiling `>`";
}
|
|
|
|
// Transpose transform (mnemonic `transpose`) of an SMEM memref, described by a
// `permutation`; one of the transforms a LayoutAttr can carry.
def TransposeTransformAttr
    : MosaicGPU_Attr<"TransposeTransform", "transpose"> {
  let summary = "Specifies a transpose transform of a memref in SMEM.";

  // NOTE(review): `Variadic<I64>` is an operand type constraint, not an
  // attribute/type parameter; confirm the generated accessor type is what is
  // intended.
  let parameters = (ins Variadic<I64>:$permutation);
  let assemblyFormat = "`<` $permutation `>`";
}
|
|
|
|
// Swizzle transform (mnemonic `swizzle`) of an SMEM memref. Its single
// parameter holds the swizzling mode as an attribute (presumably the
// `SwizzlingModeAttr` generated for MosaicGPU_SwizzlingMode).
def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> {
  let summary = "Specifies a swizzle transform of a memref in SMEM.";

  let parameters = (ins "SwizzlingModeAttr":$swizzle);
  let assemblyFormat = "`<` $swizzle `>`";
}
|
|
|
|
// Memref layout attribute (mnemonic `layout`) holding an ordered list of
// transforms. The MemRefLayoutAttrInterface methods are only declared here;
// their C++ implementations live outside this file.
def LayoutAttr : MosaicGPU_Attr<"Layout", "layout", [
    DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>
  ]> {
  let summary = "Specifies a layout of a memref in SMEM.";
  let description = [{
    This layout attribute is used to specify the layout of a memref in SMEM.
    It is composed of a number of transforms, which are applied in the order
    they are provided. The transforms can be any combination of:
    - TileTransformAttr
    - TransposeTransformAttr
    - SwizzleTransformAttr

    The num_dimensions parameter must match the rank of the memref shape.
  }];

  let parameters = (ins
    TypeParameter<"int32_t", "number of dimensions">:$num_dimensions,
    ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms
  );
  let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`";
}
|
|
|
|
// Side-effect resource standing for global memory (GMEM). `async_load`
// declares MemRead and `async_store` declares MemWrite on it.
// NOTE(review): the C++ resource class `::mosaic_gpu::GlobalMemory` must be
// defined in C++ outside this file.
def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">;
|
|
|
|
// `async_load`: schedules a GMEM -> SMEM copy that signals `barrier` on
// completion. Declares a MemRead effect on the GlobalMemory resource.
def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
    [AttrSizedOperandSegments, MemoryEffects<[MemRead<GlobalMemory>]>]> {
  let summary = "Schedules an async load of a MemRef from GMEM to SMEM";
  let description = [{
    Schedules an async copy of the contents of the `source` MemRef in GMEM to
    the `destination` MemRef in SMEM. The `destination` MemRef in SMEM must be
    contiguous.

    Upon completion of the copy, the `complete-tx(complete-count)` operation
    will always be executed on the provided `barrier`.

    The `indices` operands and the `slice_lengths` attribute define what slice
    of the GMEM `source` corresponds to the SMEM `destination`. Both `indices`
    and `slice_lengths` must have a length equal to the rank of the `source`.
    The values in `indices` are the starting indices of each dimension and the
    values in `slice_lengths` are the lengths. Providing -1 in `slice_lengths`
    indicates that the slice length is 1 and that the corresponding dimension
    should be collapsed and does not appear in the `destination` MemRef.

    Setting the `layout` attribute of the `destination` MemRef to an instance of
    `LayoutAttr` enables additional `transforms` that control how the `source`
    data is mapped to the `destination`. The transformations will be composed in
    the order they are provided. The transforms may contain up to a single
    swizzle transform that controls what swizzling is applied to the data after
    it is transformed, before it is finally written to SMEM. The transformed
    data is written in row-major order to the contiguous SMEM `destination`.
    The untransformed `source` data does not need to be contiguous, except for
    the last dimension, which must be both contiguous and the minor-most
    dimension.

    The `collective` attribute can be provided to use TMA multicast to more
    efficiently load the GMEM data in cases where multiple thread blocks are
    grouped together in a cluster and need to load the same data. Each block in
    a cluster will first load a slice from GMEM to SMEM and then the slices will
    be multicast to all other blocks in the cluster. In this way TMA multicast
    guarantees L2 cache hits. The `collective` attribute is the list of
    cluster dimensions along which to partition the input data loads.

    The `predicate` input should be set to `true` by a single thread in the
    warpgroup so that it schedules the load operation. All other threads in the
    warpgroup should set the `predicate` to `false`.
  }];

  let arguments = (ins
    MemRefOf<[AnyType]>:$source,
    MemRefOf<[AnyType]>:$destination,
    MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier,
    Variadic<I32>:$indices,
    PtxPredicate:$predicate,

    // Attributes
    DenseI64ArrayAttr:$slice_lengths,
    TypedArrayAttrBase<MosaicGPU_Dimension, "dimensions">:$collective
  );

  // `slice_lengths` and `collective` are not spelled out below; they are
  // printed and parsed through `attr-dict`.
  let assemblyFormat = [{
    `source` `(` $source `:` type($source) `)`
    `destination` `(` $destination `:` type($destination) `)`
    `barrier` `(` $barrier `:` type($barrier) `)`
    `indices` `(` $indices `)`
    `predicate` `(` $predicate `)`
    attr-dict
  }];

  let hasVerifier = 1;
}
|
|
|
|
// `async_store`: schedules an SMEM -> GMEM copy. Declares a MemWrite effect on
// the GlobalMemory resource.
def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
    [AttrSizedOperandSegments, MemoryEffects<[MemWrite<GlobalMemory>]>]> {
  let summary = "Schedules an async store of a MemRef from SMEM to GMEM";
  let description = [{
    Schedules an async store of the contents of the `source` MemRef in SMEM to
    the `destination` MemRef in GMEM. The `source` MemRef in SMEM must be
    contiguous.

    The `indices` operands and the `slice_lengths` attribute define what slice
    of the GMEM `destination` corresponds to the SMEM `source`. Both `indices`
    and `slice_lengths` must have a length equal to the rank of the
    `destination`. The values in `indices` are the starting indices of each
    dimension and the values in `slice_lengths` are the lengths. Providing -1
    in `slice_lengths` indicates that this dimension is collapsed in the
    `source` and needs to be expanded to a slice of size 1 in the
    `destination`.

    Setting the `layout` attribute of the `source` MemRef to an instance of
    `LayoutAttr` enables additional `transforms` that control how the
    `destination` data in GMEM is mapped to the `source` data in SMEM. The
    transformations will be composed in the order they are provided. The
    transforms may contain up to a single swizzle transform that is the
    swizzling mode of the `source` data in SMEM. The `source` SMEM data is
    contiguous and the transformed data is written to the `destination` GMEM
    which does not need to be contiguous.

    The `predicate` input should be set to `true` by a single thread in the
    warpgroup so that it schedules the store operation. All other threads in the
    warpgroup should set the `predicate` to `false`.
  }];

  let arguments = (ins
    MemRefOf<[AnyType]>:$source,
    MemRefOf<[AnyType]>:$destination,
    Variadic<I32>:$indices,
    PtxPredicate:$predicate,

    // Attributes
    DenseI64ArrayAttr:$slice_lengths
  );

  // `slice_lengths` is not spelled out below; it is printed and parsed
  // through `attr-dict`.
  let assemblyFormat = [{
    `source` `(` $source `:` type($source) `)`
    `destination` `(` $destination `:` type($destination) `)`
    `indices` `(` $indices `)`
    `predicate` `(` $predicate `)`
    attr-dict
  }];

  let hasVerifier = 1;
}
|
|
|
|
// Element types (f16, bf16, f32) accepted by the operands and result of
// `mosaic_gpu.wgmma`.
def MosaicGPU_WGMMASupportedType
    : AnyTypeOf<[F16, BF16, F32], "A type supported by the WGMMA operation">;
|
|
|
|
// Row-major / column-major tile layout for WGMMA.
// NOTE(review): not referenced by any op in this section; confirm its users.
def MosaicGPU_WGMMALayout : I32EnumAttr<
    "WGMMALayout", "The layout of the tiles of a WGMMA operation",
    [I32EnumAttrCase<"RowMajor", 0>, I32EnumAttrCase<"ColumnMajor", 1>]> {
  let cppNamespace = "::mosaic_gpu";
}
|
|
|
|
// `wgmma`: accumulator = a * b + accumulator, lowered to as many PTX-level
// WGMMA instructions as needed. The result type is inferred from operand 0
// (the accumulator) by the hand-written `inferReturnTypes` below.
def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
  let summary = "Multiply two matrices asynchronously using warpgroup-level matrix multiply operations.";
  let description = [{
    Schedules WGMMA operations that perform the following matrix multiply and
    accumulate:

      accumulator = a * b + accumulator

    This operation supports larger inputs than the PTX-level WGMMA operation
    and will schedule as many PTX-level WGMMA operations as needed to
    accomplish the calculation. The `b` matrix, and optionally `a`, needs to be
    provided in a 4-dimensional form, where the two minor-most dimensions
    express the tile (group) size and the two major-most dimensions represent
    the total number of tiles in each direction.

    The inputs should have the following shapes:
      - If `a` is in shared memory:
        - a: [groups_m, groups_k, 64, k]
      - If `a` is in registers:
        - a: [groups_m * 64, groups_k * k]
      - b: [groups_k, groups_n, k, k]
      - accumulator: [groups_m * 64, groups_n * k]
    Where:
      - `k == swizzle / element_bytewidth` (for `kNoSwizzle`, `swizzle` is 16).

    The output has the same shape and type as the input accumulator.

    The `accumulator` is always in registers and `b` is always in shared memory.
    The last two dimensions of any input in shared memory may be physically
    transposed in memory. This is inferred from the strides of the provided
    memrefs. `a` and `b` must have the same element type and when `a` is in
    registers only F16 or BF16 are supported.

    The `accumulator` must be a vector with a FragmentedLayout. The WGMMA
    operation will be executed in the async proxy and any inputs in
    registers need to be synchronized with a memory fence.

    Usually `a` is read from shared memory if it is used directly in the WGMMA
    operation. If `a` needs to be transformed before it is used in the WGMMA
    operation, it may be more convenient to read it directly from registers.
    This avoids the need to store the data and wait for a fence.
  }];

  let arguments = (ins
    VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>:$accumulator,
    // `a` may be either a memref or a vector.
    AnyTypeOf<[
      MemRefOf<[MosaicGPU_WGMMASupportedType]>,
      VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>]>:$a,
    MemRefOf<[MosaicGPU_WGMMASupportedType]>:$b,

    // Attributes
    DefaultValuedAttr<MosaicGPU_SwizzlingMode, "SwizzlingMode::k128ByteSwizzle">:$swizzle
  );
  let results = (outs VectorOfAnyRankOf<[MosaicGPU_WGMMASupportedType]>);

  let assemblyFormat = [{
    `accumulator` `(` $accumulator `:` type($accumulator) `)`
    `a` `(` $a `:` type($a) `)`
    `b` `(` $b `:` type($b) `)`
    attr-dict
    `->` type(results)
  }];

  // InferTypeOpInterface hook: the single result takes the type of operand 0,
  // which is `accumulator` (first entry of `arguments`).
  let extraClassDeclaration = [{
    static llvm::LogicalResult inferReturnTypes(
        mlir::MLIRContext *,
        std::optional<mlir::Location> location,
        mlir::ValueRange operands,
        mlir::DictionaryAttr attributes,
        mlir::OpaqueProperties properties,
        mlir::RegionRange regions,
        llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
      if (operands.empty()) {
        return ::mlir::emitOptionalError(
            location, "expected non-empty operands");
      }
      inferredReturnTypes.assign({operands[0].getType()});
      return ::mlir::success();
    }
  }];

  let hasVerifier = 1;
}
|
|
|
|
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_GPU_MOSAIC_GPU_TD_
|