mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow. PiperOrigin-RevId: 725716812
This commit is contained in:
parent
c2bd1576da
commit
6fc1c61520
@ -240,12 +240,33 @@ def _vector_splat_op_lowering_rule(
|
||||
return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]
|
||||
|
||||
|
||||
def layout_to_swizzle(layout: ir.Attribute) -> mgpu.SwizzlingMode:
  """Returns the swizzle mode encoded in the given memref layout.

  Args:
    layout: The layout attribute of a memref type.

  Returns:
    `kNoSwizzle` if `layout` is not an `mgpu.LayoutAttr`. Otherwise, the
    swizzle of the single swizzle transform held by the layout.

  Raises:
    ValueError: If the `LayoutAttr` does not hold exactly one transform.
    NotImplementedError: If the single transform is not a swizzle transform.
  """
  if not mgpu.LayoutAttr.isinstance(layout):
    return mgpu.SwizzlingMode.kNoSwizzle

  transforms = mgpu.LayoutAttr(layout).transforms
  num_transforms = len(transforms)
  # An empty transform list is rejected just like a list with several entries,
  # but it gets its own message so the error does not claim "multiple".
  if num_transforms == 0:
    raise ValueError(
        f"{layout} has no transforms; expected exactly one swizzle transform"
    )
  if num_transforms > 1:
    raise ValueError(f"{layout} has multiple transforms")

  (transform,) = transforms
  if not mgpu.SwizzleTransformAttr.isinstance(transform):
    raise NotImplementedError("Only swizzle transforms are supported.")

  # TODO(dasenov): Swizzling can change if the ref is sliced in certain
  # ways. We might want to enforce some restrictions here.
  return mgpu.SwizzleTransformAttr(transform).swizzle
|
||||
|
||||
|
||||
@_register_lowering(mgpu.AsyncLoadOp)
|
||||
def _mgpu_async_load_op_lowering_rule(
|
||||
ctx: LoweringContext, load_op: mgpu.AsyncLoadOp
|
||||
) -> Sequence[ir.Value]:
|
||||
assert ctx.launch_context is not None
|
||||
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
|
||||
|
||||
dst_layout = ir.MemRefType(load_op.destination.type).layout
|
||||
# TODO(dasenov): Add support for the remaining op properties.
|
||||
ctx.launch_context.async_copy(
|
||||
src_ref=load_op.source,
|
||||
@ -253,7 +274,7 @@ def _mgpu_async_load_op_lowering_rule(
|
||||
barrier=barrier,
|
||||
arrive=False,
|
||||
uniform=True,
|
||||
swizzle=load_op.swizzle.value,
|
||||
swizzle=layout_to_swizzle(dst_layout),
|
||||
predicate=ctx.single_thread_per_warpgroup_predicate,
|
||||
)
|
||||
return []
|
||||
@ -264,11 +285,13 @@ def _mgpu_async_store_op_lowering_rule(
|
||||
ctx: LoweringContext, store_op: mgpu.AsyncStoreOp
|
||||
) -> Sequence[ir.Value]:
|
||||
assert ctx.launch_context is not None
|
||||
|
||||
src_layout = ir.MemRefType(store_op.source.type).layout
|
||||
# TODO(dasenov): Add support for the remaining op properties.
|
||||
ctx.launch_context.async_copy(
|
||||
src_ref=store_op.source,
|
||||
dst_ref=store_op.destination,
|
||||
swizzle=store_op.swizzle.value,
|
||||
swizzle=layout_to_swizzle(src_layout),
|
||||
uniform=True,
|
||||
predicate=ctx.single_thread_per_warpgroup_predicate,
|
||||
)
|
||||
|
@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h"
|
||||
#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
@ -31,4 +35,39 @@ NB_MODULE(_mosaic_gpu_ext, m) {
|
||||
}
|
||||
},
|
||||
nb::arg("context"), nb::arg("load") = true);
|
||||
|
||||
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
|
||||
m, "SwizzleTransformAttr", MosaicGpuIsASwizzleTransformAttr)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](nb::object cls, int32_t swizzle, MlirContext ctx) {
|
||||
return cls(MosaicGpuSwizzleTransformAttrGet(
|
||||
ctx, static_cast<int32_t>(swizzle)));
|
||||
},
|
||||
nb::arg("cls"), nb::arg("swizzle"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Creates a SwizzleTransformAttr with the given swizzle.")
|
||||
.def_property_readonly("swizzle", [](MlirAttribute self) {
|
||||
return MosaicGpuSwizzleTransformAttrGetSwizzle(self);
|
||||
});
|
||||
|
||||
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
|
||||
m, "LayoutAttr", MosaicGpuIsALayoutAttr)
|
||||
.def_classmethod(
|
||||
"get",
|
||||
[](nb::object cls, int32_t num_dimensions,
|
||||
std::vector<MlirAttribute>& transforms, MlirContext ctx) {
|
||||
return cls(MosaicGpuLayoutAttrGet(
|
||||
ctx, num_dimensions, transforms.data(), transforms.size()));
|
||||
},
|
||||
nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Creates a LayoutAttr with the given transforms.")
|
||||
.def_property_readonly("transforms", [](MlirAttribute self) {
|
||||
std::vector<MlirAttribute> result;
|
||||
for (int i = 0; i < MosaicGpuLayoutAttrGetTransformsSize(self); ++i) {
|
||||
result.push_back(MosaicGpuLayoutAttrGetTransform(self, i));
|
||||
}
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
@ -198,10 +198,12 @@ genrule(
|
||||
)
|
||||
|
||||
DIALECT_CAPI_SOURCES = [
|
||||
":integrations/c/attributes.cc",
|
||||
":integrations/c/gpu_dialect.cc",
|
||||
]
|
||||
|
||||
DIALECT_CAPI_HEADERS = [
|
||||
":integrations/c/attributes.h",
|
||||
":integrations/c/gpu_dialect.h",
|
||||
]
|
||||
|
||||
@ -212,7 +214,10 @@ cc_library(
|
||||
deps = [
|
||||
":mosaic_gpu",
|
||||
":mosaic_gpu_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:CAPIIR",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -234,7 +239,10 @@ cc_library(
|
||||
deps = [
|
||||
":mosaic_gpu",
|
||||
":mosaic_gpu_inc_gen",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:CAPIIRObjects",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
62
jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc
Normal file
62
jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc
Normal file
@ -0,0 +1,62 @@
|
||||
#include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SwizzleTransformAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Returns true iff `attr` wraps a mosaic_gpu::SwizzleTransformAttr.
bool MosaicGpuIsASwizzleTransformAttr(MlirAttribute attr) {
  mlir::Attribute unwrapped = unwrap(attr);
  return mlir::isa<mosaic_gpu::SwizzleTransformAttr>(unwrapped);
}
|
||||
// Builds a SwizzleTransformAttr in `ctx`. `swizzle` is converted to a
// mosaic_gpu::SwizzlingMode with a plain static_cast; no range check is done.
MlirAttribute MosaicGpuSwizzleTransformAttrGet(MlirContext ctx,
                                               int32_t swizzle) {
  auto mode = static_cast<mosaic_gpu::SwizzlingMode>(swizzle);
  auto mode_attr = mosaic_gpu::SwizzlingModeAttr::get(unwrap(ctx), mode);
  return wrap(mosaic_gpu::SwizzleTransformAttr::get(unwrap(ctx), mode_attr));
}
|
||||
// Returns the swizzling mode stored in `attr` as an int32_t. `attr` is
// converted with mlir::cast, so it must be a SwizzleTransformAttr.
int32_t MosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) {
  auto transform = mlir::cast<mosaic_gpu::SwizzleTransformAttr>(unwrap(attr));
  auto mode = transform.getSwizzle().getValue();
  return static_cast<int32_t>(mode);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LayoutAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns true iff `attr` wraps a mosaic_gpu::LayoutAttr.
bool MosaicGpuIsALayoutAttr(MlirAttribute attr) {
  mlir::Attribute unwrapped = unwrap(attr);
  return mlir::isa<mosaic_gpu::LayoutAttr>(unwrapped);
}
|
||||
|
||||
// Builds a LayoutAttr in `ctx` from `num_dimensions` and the first
// `transforms_size` entries of the `transforms` array.
MlirAttribute MosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions,
                                     MlirAttribute* transforms,
                                     int32_t transforms_size) {
  // Unwrap the C handles into C++ attributes before building the layout.
  std::vector<mlir::Attribute> attrs;
  attrs.reserve(transforms_size);
  for (int idx = 0; idx < transforms_size; ++idx) {
    attrs.push_back(unwrap(transforms[idx]));
  }
  auto layout =
      mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions, attrs);
  return wrap(layout);
}
|
||||
|
||||
// Returns the number of transforms held by `attr`. `attr` is converted with
// mlir::cast, so it must be a LayoutAttr.
int32_t MosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) {
  auto layout = mlir::cast<mosaic_gpu::LayoutAttr>(unwrap(attr));
  // The narrowing to the C API's int32_t return type is spelled out.
  return static_cast<int32_t>(layout.getTransforms().size());
}
|
||||
|
||||
// Returns the transform at position `index` of `attr`. `attr` is converted
// with mlir::cast, so it must be a LayoutAttr; `index` is not range-checked
// here.
MlirAttribute MosaicGpuLayoutAttrGetTransform(MlirAttribute attr,
                                              int32_t index) {
  auto layout = mlir::cast<mosaic_gpu::LayoutAttr>(unwrap(attr));
  auto transforms = layout.getTransforms();
  return wrap(transforms[index]);
}
|
60
jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h
Normal file
60
jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h
Normal file
@ -0,0 +1,60 @@
|
||||
/* Copyright 2025 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_
|
||||
#define JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SwizzleTransformAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_CAPI_EXPORTED bool MosaicGpuIsASwizzleTransformAttr(MlirAttribute attr);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
MosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle);
|
||||
|
||||
MLIR_CAPI_EXPORTED int32_t
|
||||
MosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LayoutAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_CAPI_EXPORTED bool MosaicGpuIsALayoutAttr(MlirAttribute attr);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
MosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions,
|
||||
MlirAttribute* transforms, int32_t transforms_size);
|
||||
|
||||
MLIR_CAPI_EXPORTED int32_t
|
||||
MosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirAttribute
|
||||
MosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_
|
@ -485,6 +485,14 @@ llvm::LogicalResult WGMMAOp::verify() {
|
||||
return llvm::success();
|
||||
}
|
||||
|
||||
mlir::AffineMap LayoutAttr::getAffineMap() const {
  // The map returned here is always the identity over `num_dimensions`
  // dimensions. That is technically not correct, but the map is not actually
  // used anywhere: it is only requested while the layout attribute is being
  // verified, so it just needs to be semi-valid.
  auto num_dims = getNumDimensions();
  return mlir::AffineMap::getMultiDimIdentityMap(num_dims, getContext());
}
|
||||
|
||||
void MosaicGPUDialect::initialize() {
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
|
@ -20,6 +20,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/IR/AttrTypeBase.td"
|
||||
include "mlir/IR/BuiltinAttributeInterfaces.td"
|
||||
include "mlir/IR/BuiltinTypeInterfaces.td"
|
||||
include "mlir/IR/CommonAttrConstraints.td"
|
||||
include "mlir/IR/CommonTypeConstraints.td"
|
||||
@ -39,8 +40,8 @@ class MosaicGPU_Type<string name, string mnemonic_, list<Trait> traits = []>
|
||||
let mnemonic = mnemonic_;
|
||||
}
|
||||
|
||||
class MosaicGPU_Attr<string name, string mnemonic_>
|
||||
: AttrDef<MosaicGPU_Dialect, name> {
|
||||
class MosaicGPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
|
||||
: AttrDef<MosaicGPU_Dialect, name, traits> {
|
||||
let mnemonic = mnemonic_;
|
||||
}
|
||||
|
||||
@ -194,7 +195,7 @@ def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode",
|
||||
|
||||
def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
|
||||
let parameters = (ins Variadic<I64>:$tiling);
|
||||
let summary = "Tiles a suffix of memref dimensions.";
|
||||
let summary = "Specifies a transform that tiles suffix dimensions of a memref in SMEM.";
|
||||
let description = [{
|
||||
For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
|
||||
the shape of the result will be (5, 2, 4, 64, 32). The shape always ends
|
||||
@ -210,10 +211,38 @@ def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
|
||||
|
||||
def TransposeTransformAttr : MosaicGPU_Attr<"TransposeTransform", "transpose"> {
|
||||
let parameters = (ins Variadic<I64>:$permutation);
|
||||
let summary = "Specifies how to transpose a memref.";
|
||||
let summary = "Specifies a transpose transform of a memref in SMEM.";
|
||||
let assemblyFormat = "`<` $permutation `>`";
|
||||
}
|
||||
|
||||
def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> {
  let summary = "Specifies a swizzle transform of a memref in SMEM.";

  // The single parameter is the swizzling mode, printed between angle
  // brackets in the assembly form.
  let parameters = (ins "SwizzlingModeAttr":$swizzle);
  let assemblyFormat = "`<` $swizzle `>`";
}
|
||||
|
||||
def LayoutAttr : MosaicGPU_Attr<"Layout", "layout",
    [DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
  let summary = "Specifies a layout of a memref in SMEM.";
  let description = [{
    This layout attribute is used to specify the layout of a memref in SMEM.
    It is composed of a number of transforms, which are applied in the order
    they are provided. The transforms can be any combination of:
    - TileTransformAttr
    - TransposeTransformAttr
    - SwizzleTransformAttr

    The num_dimensions parameter must match the rank of the memref shape.
  }];

  // Parameters: the number of dimensions, then the ordered list of transforms.
  let parameters = (ins
    TypeParameter<"int32_t", "number of dimensions">:$num_dimensions,
    ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms
  );
  let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`";
}
|
||||
|
||||
def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">;
|
||||
|
||||
def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
|
||||
@ -235,14 +264,16 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
|
||||
indicates that the slice length is 1 and that the corresponding dimension
|
||||
should be collapsed and does not appear in the `destination` MemRef.
|
||||
|
||||
Additional `transforms` may be provided to control how the `source` data is
|
||||
mapped to the `destination`. The transformations will be composed in the
|
||||
order they are provided. The `swizzle` attribute controls what swizzling
|
||||
is applied to the data after it is transformed, before it is finally written
|
||||
to SMEM. The transformed data is written in row-major order to the
|
||||
contiguous SMEM `destination`. The untransformed `source` data does not need
|
||||
to be contiguous, except for the last dimension, which needs to be
|
||||
contiguous and the minor-most dimension.
|
||||
Setting the `layout` attribute of the `destination` MemRef to an instance of
|
||||
`LayoutAttr` enables additional `transforms` that control how the `source`
|
||||
data is mapped to the `destination`. The transformations will be composed in
|
||||
the order they are provided. The transforms may contain up to a single
|
||||
swizzle transform that controls what swizzling is applied to the data after
|
||||
it is transformed, before it is finally written to SMEM. The transformed
|
||||
data is written in row-major order to the contiguous SMEM `destination`.
|
||||
The untransformed `source` data does not need to be contiguous, except for
|
||||
the last dimension, which needs to be contiguous and the minor-most
|
||||
dimension.
|
||||
|
||||
The `collective` attribute can be provided to use TMA multicast to more
|
||||
efficiently load the GMEM data in cases where multiple thread blocks are
|
||||
@ -266,8 +297,6 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
|
||||
|
||||
// Attributes
|
||||
DenseI64ArrayAttr:$slice_lengths,
|
||||
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
|
||||
DefaultValuedAttr<MosaicGPU_SwizzlingMode, "SwizzlingMode::kNoSwizzle">:$swizzle,
|
||||
TypedArrayAttrBase<MosaicGPU_Dimension, "dimensions">:$collective
|
||||
);
|
||||
|
||||
@ -299,11 +328,13 @@ def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
|
||||
indicates that this dimension is collapsed in the `source` and needs to be
|
||||
expanded to a slice of size 1 in the `destination`.
|
||||
|
||||
Additional `transforms` may be provided to control how the `destination`
|
||||
data in GMEM is mapped to the `source` data in SMEM. The transformations
|
||||
will be composed in the order they are provided. The `swizzle` attribute
|
||||
is the swizzling mode of the `source` data in SMEM. The `source` SMEM data
|
||||
is contiguous and the transformed data is written to the `destination` GMEM
|
||||
Setting the `layout` attribute of the `source` MemRef to an instance of
|
||||
`LayoutAttr` enables additional `transforms` that control how the
|
||||
`destination` data in GMEM is mapped to the `source` data in SMEM. The
|
||||
transformations will be composed in the order they are provided. The
|
||||
transforms may contain up to a single swizzle transform that is the
|
||||
swizzling mode of the `source` data in SMEM. The `source` SMEM data is
|
||||
contiguous and the transformed data is written to the `destination` GMEM
|
||||
which does not need to be contiguous.
|
||||
|
||||
The `predicate` input should be set to `true` by a single thread in the
|
||||
@ -318,9 +349,7 @@ def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
|
||||
PtxPredicate:$predicate,
|
||||
|
||||
// Attributes
|
||||
DenseI64ArrayAttr:$slice_lengths,
|
||||
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
|
||||
DefaultValuedAttr<MosaicGPU_SwizzlingMode, "SwizzlingMode::kNoSwizzle">:$swizzle
|
||||
DenseI64ArrayAttr:$slice_lengths
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
|
@ -163,7 +163,6 @@ class DialectTest(MosaicGpuTest):
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
@ -190,7 +189,6 @@ class DialectTest(MosaicGpuTest):
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
@ -217,7 +215,6 @@ class DialectTest(MosaicGpuTest):
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[-2, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
@ -245,7 +242,6 @@ class DialectTest(MosaicGpuTest):
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[-1, 4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
@ -271,7 +267,6 @@ class DialectTest(MosaicGpuTest):
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
@ -297,7 +292,6 @@ class DialectTest(MosaicGpuTest):
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
@ -324,7 +318,6 @@ class DialectTest(MosaicGpuTest):
|
||||
barrier,
|
||||
indices,
|
||||
slice_lengths=[4],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([
|
||||
ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x),
|
||||
ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x),
|
||||
@ -356,7 +349,6 @@ class DialectTest(MosaicGpuTest):
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
@ -380,7 +372,6 @@ class DialectTest(MosaicGpuTest):
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
@ -404,7 +395,6 @@ class DialectTest(MosaicGpuTest):
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[-2, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
@ -429,7 +419,6 @@ class DialectTest(MosaicGpuTest):
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[-1, 4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
@ -452,7 +441,6 @@ class DialectTest(MosaicGpuTest):
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
@ -475,7 +463,6 @@ class DialectTest(MosaicGpuTest):
|
||||
destination,
|
||||
indices,
|
||||
slice_lengths=[4, 8],
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -2254,6 +2254,35 @@ class LayoutTest(TestCase):
|
||||
np.testing.assert_array_equal(f(x), x)
|
||||
|
||||
|
||||
def with_swizzle(
    mem_ref: ir.Value, swizzle: mgpu_dialect.SwizzlingMode
) -> ir.Value:
  """Casts `mem_ref` to a memref type whose layout ends with a swizzle.

  When the memref already carries a `LayoutAttr`, the swizzle transform is
  appended after its existing transforms. Any other layout is dropped and
  replaced by a `LayoutAttr` holding only the swizzle transform.

  Args:
    mem_ref: The memref value to cast.
    swizzle: The swizzling mode to record in the layout.

  Returns:
    The result of a `memref.cast` of `mem_ref` to the new memref type.
  """
  ref_ty = ir.MemRefType(mem_ref.type)

  # Start from the existing transforms, or from an empty layout of the same
  # rank when the current layout is not a LayoutAttr.
  if mgpu_dialect.LayoutAttr.isinstance(ref_ty.layout):
    base_layout = mgpu_dialect.LayoutAttr(ref_ty.layout)
  else:
    base_layout = mgpu_dialect.LayoutAttr.get(ref_ty.rank, [])

  swizzle_attr = mgpu_dialect.SwizzleTransformAttr.get(swizzle)
  all_transforms = base_layout.transforms + [swizzle_attr]
  swizzled_layout = mgpu_dialect.LayoutAttr.get(ref_ty.rank, all_transforms)

  swizzled_ty = ir.MemRefType.get(
      ref_ty.shape,
      ref_ty.element_type,
      swizzled_layout,
      ref_ty.memory_space,
  )
  return memref.cast(swizzled_ty, mem_ref)
|
||||
|
||||
|
||||
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
"""Device tests with lowering from the MLIR dialect and layout inference."""
|
||||
|
||||
@ -2328,23 +2357,19 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
source=a_gmem_ref,
|
||||
destination=a_smem_ref,
|
||||
destination=with_swizzle(a_smem_ref, swizzle),
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32, zero_i32],
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
source=b_gmem_ref,
|
||||
destination=b_smem_ref,
|
||||
destination=with_swizzle(b_smem_ref, swizzle),
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32, zero_i32],
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
|
||||
parities = memref.load(tma_barrier.phases, [])
|
||||
@ -2366,12 +2391,10 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
|
||||
# SMEM -> GMEM
|
||||
mgpu_dialect.async_store(
|
||||
source=result_smem_ref,
|
||||
source=with_swizzle(result_smem_ref, swizzle),
|
||||
destination=result_gmem_ref,
|
||||
indices=[zero_i32, zero_i32],
|
||||
slice_lengths=shape,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
nvvm.cp_async_bulk_wait_group(0)
|
||||
utils.warpgroup_barrier()
|
||||
@ -2437,23 +2460,19 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
# GMEM -> SMEM
|
||||
mgpu_dialect.async_load(
|
||||
source=a_gmem_ref,
|
||||
destination=a_smem_ref,
|
||||
destination=with_swizzle(a_smem_ref, swizzle),
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32, zero_i32, zero_i32, zero_i32],
|
||||
slice_lengths=shape_a,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
mgpu_dialect.async_load(
|
||||
source=b_gmem_ref,
|
||||
destination=b_smem_ref,
|
||||
destination=with_swizzle(b_smem_ref, swizzle),
|
||||
barrier=dialect_barrier,
|
||||
indices=[zero_i32, zero_i32, zero_i32, zero_i32],
|
||||
slice_lengths=shape_b,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
collective=ir.ArrayAttr.get([]),
|
||||
swizzle=swizzle,
|
||||
)
|
||||
|
||||
parities = memref.load(tma_barrier.phases, [])
|
||||
@ -2486,8 +2505,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
|
||||
destination=result_gmem_ref,
|
||||
indices=[zero_i32, zero_i32],
|
||||
slice_lengths=shape_result,
|
||||
transforms=ir.ArrayAttr.get([]),
|
||||
swizzle=mgpu_dialect.SwizzlingMode.kNoSwizzle,
|
||||
)
|
||||
nvvm.cp_async_bulk_wait_group(0)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user