[Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).

Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
This commit is contained in:
Dimitar (Mitko) Asenov 2025-02-11 11:50:46 -08:00 committed by jax authors
parent c2bd1576da
commit 6fc1c61520
9 changed files with 287 additions and 54 deletions

View File

@ -240,12 +240,33 @@ def _vector_splat_op_lowering_rule(
return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)]
def layout_to_swizzle(layout: ir.Attribute) -> mgpu.SwizzlingMode:
  """Returns the swizzle mode encoded in the given memref layout.

  If the layout is not a LayoutAttr, the swizzle is kNoSwizzle. Otherwise,
  the layout must consist of exactly one swizzle transform.

  Args:
    layout: The layout attribute of a memref type.

  Returns:
    The swizzle stored in the layout's single SwizzleTransformAttr, or
    kNoSwizzle when the layout is not a LayoutAttr.

  Raises:
    ValueError: If the LayoutAttr does not hold exactly one transform (this
      includes a LayoutAttr with no transforms at all).
    NotImplementedError: If the single transform is not a swizzle transform.
  """
  if not mgpu.LayoutAttr.isinstance(layout):
    return mgpu.SwizzlingMode.kNoSwizzle

  transforms = mgpu.LayoutAttr(layout).transforms
  if len(transforms) != 1:
    # The check rejects both zero and several transforms, so the message
    # reports the actual count instead of claiming there are "multiple".
    raise ValueError(
        f"{layout} must have exactly one transform, got {len(transforms)}"
    )

  transform = transforms[0]
  if not mgpu.SwizzleTransformAttr.isinstance(transform):
    raise NotImplementedError("Only swizzle transforms are supported.")

  # TODO(dasenov): Swizzling can change if the ref is sliced in certain
  # ways. We might want to enforce some restrictions here.
  return mgpu.SwizzleTransformAttr(transform).swizzle
@_register_lowering(mgpu.AsyncLoadOp)
def _mgpu_async_load_op_lowering_rule(
ctx: LoweringContext, load_op: mgpu.AsyncLoadOp
) -> Sequence[ir.Value]:
assert ctx.launch_context is not None
barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier)
dst_layout = ir.MemRefType(load_op.destination.type).layout
# TODO(dasenov): Add support for the remaining op properties.
ctx.launch_context.async_copy(
src_ref=load_op.source,
@ -253,7 +274,7 @@ def _mgpu_async_load_op_lowering_rule(
barrier=barrier,
arrive=False,
uniform=True,
swizzle=load_op.swizzle.value,
swizzle=layout_to_swizzle(dst_layout),
predicate=ctx.single_thread_per_warpgroup_predicate,
)
return []
@ -264,11 +285,13 @@ def _mgpu_async_store_op_lowering_rule(
ctx: LoweringContext, store_op: mgpu.AsyncStoreOp
) -> Sequence[ir.Value]:
assert ctx.launch_context is not None
src_layout = ir.MemRefType(store_op.source.type).layout
# TODO(dasenov): Add support for the remaining op properties.
ctx.launch_context.async_copy(
src_ref=store_op.source,
dst_ref=store_op.destination,
swizzle=store_op.swizzle.value,
swizzle=layout_to_swizzle(src_layout),
uniform=True,
predicate=ctx.single_thread_per_warpgroup_predicate,
)

View File

@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <vector>
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep
#include "nanobind/nanobind.h"
#include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h"
#include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h"
namespace nb = nanobind;
@ -31,4 +35,39 @@ NB_MODULE(_mosaic_gpu_ext, m) {
}
},
nb::arg("context"), nb::arg("load") = true);
// Python class `SwizzleTransformAttr`: an MLIR attribute subclass recognised
// via MosaicGpuIsASwizzleTransformAttr.
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
    m, "SwizzleTransformAttr", MosaicGpuIsASwizzleTransformAttr)
    .def_classmethod(
        "get",
        // Builds the attribute from the integer value of a swizzling mode.
        // `swizzle` is already an int32_t, so it is forwarded without a cast.
        [](nb::object cls, int32_t swizzle, MlirContext ctx) {
          return cls(MosaicGpuSwizzleTransformAttrGet(ctx, swizzle));
        },
        nb::arg("cls"), nb::arg("swizzle"),
        // `context` may be omitted or None from Python.
        nb::arg("context").none() = nb::none(),
        "Creates a SwizzleTransformAttr with the given swizzle.")
    // Read-only `swizzle` property: the swizzling mode as an integer.
    .def_property_readonly("swizzle", [](MlirAttribute self) {
      return MosaicGpuSwizzleTransformAttrGetSwizzle(self);
    });
// Python class `LayoutAttr`: an MLIR attribute subclass recognised via
// MosaicGpuIsALayoutAttr.
mlir::python::nanobind_adaptors::mlir_attribute_subclass(
    m, "LayoutAttr", MosaicGpuIsALayoutAttr)
    .def_classmethod(
        "get",
        // Builds the attribute from a dimension count and a list of transform
        // attributes. The vector is taken by non-const reference because the
        // C API expects a mutable `MlirAttribute*`.
        [](nb::object cls, int32_t num_dimensions,
           std::vector<MlirAttribute>& transforms, MlirContext ctx) {
          return cls(MosaicGpuLayoutAttrGet(
              ctx, num_dimensions, transforms.data(), transforms.size()));
        },
        nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"),
        // `context` may be omitted or None from Python.
        nb::arg("context").none() = nb::none(),
        "Creates a LayoutAttr with the given transforms.")
    // Read-only `transforms` property: the transforms, in stored order.
    .def_property_readonly("transforms", [](MlirAttribute self) {
      // Query the size once instead of on every loop iteration, and
      // pre-size the result so that push_back never reallocates.
      const int32_t num_transforms = MosaicGpuLayoutAttrGetTransformsSize(self);
      std::vector<MlirAttribute> result;
      result.reserve(num_transforms);
      for (int32_t i = 0; i < num_transforms; ++i) {
        result.push_back(MosaicGpuLayoutAttrGetTransform(self, i));
      }
      return result;
    });
}

View File

@ -198,10 +198,12 @@ genrule(
)
DIALECT_CAPI_SOURCES = [
":integrations/c/attributes.cc",
":integrations/c/gpu_dialect.cc",
]
DIALECT_CAPI_HEADERS = [
":integrations/c/attributes.h",
":integrations/c/gpu_dialect.h",
]
@ -212,7 +214,10 @@ cc_library(
deps = [
":mosaic_gpu",
":mosaic_gpu_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
@ -234,7 +239,10 @@ cc_library(
deps = [
":mosaic_gpu",
":mosaic_gpu_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIRObjects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
alwayslink = True,
)

View File

@ -0,0 +1,62 @@
#include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h"
#include <cstdint>
#include <vector>
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
//===----------------------------------------------------------------------===//
// SwizzleTransformAttr
//===----------------------------------------------------------------------===//
// Returns true iff `attr` is a mosaic_gpu::SwizzleTransformAttr.
bool MosaicGpuIsASwizzleTransformAttr(MlirAttribute attr) {
  mlir::Attribute unwrapped = unwrap(attr);
  return mlir::isa<mosaic_gpu::SwizzleTransformAttr>(unwrapped);
}
// Creates a SwizzleTransformAttr in `ctx`. `swizzle` is the integer value of
// a mosaic_gpu::SwizzlingMode; it is cast to the enum without validation.
MlirAttribute MosaicGpuSwizzleTransformAttrGet(MlirContext ctx,
                                               int32_t swizzle) {
  mlir::MLIRContext* context = unwrap(ctx);
  const auto mode = static_cast<mosaic_gpu::SwizzlingMode>(swizzle);
  auto mode_attr = mosaic_gpu::SwizzlingModeAttr::get(context, mode);
  return wrap(mosaic_gpu::SwizzleTransformAttr::get(context, mode_attr));
}
// Returns the swizzling mode stored in `attr` as its integer value. `attr`
// must be a SwizzleTransformAttr (enforced by mlir::cast).
int32_t MosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) {
  auto swizzle_transform =
      mlir::cast<mosaic_gpu::SwizzleTransformAttr>(unwrap(attr));
  const auto mode = swizzle_transform.getSwizzle().getValue();
  return static_cast<int32_t>(mode);
}
//===----------------------------------------------------------------------===//
// LayoutAttr
//===----------------------------------------------------------------------===//
// Returns true iff `attr` is a mosaic_gpu::LayoutAttr.
bool MosaicGpuIsALayoutAttr(MlirAttribute attr) {
  mlir::Attribute unwrapped = unwrap(attr);
  return mlir::isa<mosaic_gpu::LayoutAttr>(unwrapped);
}
// Creates a LayoutAttr in `ctx` from `num_dimensions` and the first
// `transforms_size` attributes pointed to by `transforms`.
MlirAttribute MosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions,
                                     MlirAttribute* transforms,
                                     int32_t transforms_size) {
  // Convert the C handles to C++ attributes, preserving their order.
  std::vector<mlir::Attribute> cpp_transforms;
  cpp_transforms.reserve(transforms_size);
  for (int idx = 0; idx < transforms_size; ++idx) {
    mlir::Attribute transform = unwrap(transforms[idx]);
    cpp_transforms.push_back(transform);
  }
  auto layout = mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions,
                                            cpp_transforms);
  return wrap(layout);
}
// Returns the number of transforms held by `attr`, which must be a LayoutAttr
// (enforced by mlir::cast). The count is narrowed to int32_t.
int32_t MosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) {
  auto layout = mlir::cast<mosaic_gpu::LayoutAttr>(unwrap(attr));
  auto transforms = layout.getTransforms();
  return static_cast<int32_t>(transforms.size());
}
// Returns the transform at position `index` of `attr`, which must be a
// LayoutAttr (enforced by mlir::cast). No bounds check is performed here.
MlirAttribute MosaicGpuLayoutAttrGetTransform(MlirAttribute attr,
                                              int32_t index) {
  auto layout = mlir::cast<mosaic_gpu::LayoutAttr>(unwrap(attr));
  auto transforms = layout.getTransforms();
  return wrap(transforms[index]);
}

View File

@ -0,0 +1,60 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_
#define JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_
#include <stdint.h>
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
//===----------------------------------------------------------------------===//
// SwizzleTransformAttr
//===----------------------------------------------------------------------===//
// Returns true if `attr` is a Mosaic GPU SwizzleTransformAttr.
MLIR_CAPI_EXPORTED bool MosaicGpuIsASwizzleTransformAttr(MlirAttribute attr);
// Creates a SwizzleTransformAttr in `ctx`. `swizzle` is the integer value of
// the swizzling mode enum.
MLIR_CAPI_EXPORTED MlirAttribute
MosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle);
// Returns the swizzling mode of `attr` as its integer value. `attr` must be a
// SwizzleTransformAttr.
MLIR_CAPI_EXPORTED int32_t
MosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr);
//===----------------------------------------------------------------------===//
// LayoutAttr
//===----------------------------------------------------------------------===//
// Returns true if `attr` is a Mosaic GPU LayoutAttr.
MLIR_CAPI_EXPORTED bool MosaicGpuIsALayoutAttr(MlirAttribute attr);
// Creates a LayoutAttr in `ctx` with the given number of dimensions and the
// `transforms_size` transform attributes stored at `transforms`, in order.
MLIR_CAPI_EXPORTED MlirAttribute
MosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions,
MlirAttribute* transforms, int32_t transforms_size);
// Returns the number of transforms in `attr`, which must be a LayoutAttr.
MLIR_CAPI_EXPORTED int32_t
MosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr);
// Returns the transform at position `index` of `attr`, which must be a
// LayoutAttr. The implementation does not bounds-check `index`; callers must
// keep it in [0, MosaicGpuLayoutAttrGetTransformsSize(attr)).
MLIR_CAPI_EXPORTED MlirAttribute
MosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index);
#ifdef __cplusplus
}
#endif
#endif // JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_

View File

@ -485,6 +485,14 @@ llvm::LogicalResult WGMMAOp::verify() {
return llvm::success();
}
// Returns the affine map required of a memref layout attribute.
mlir::AffineMap LayoutAttr::getAffineMap() const {
  // The result is always an identity map over `num_dimensions` dimensions.
  // That is technically not correct for a transformed layout, but the map is
  // not used anywhere: it is only requested while the layout attribute is
  // verified, so it merely has to be semi-valid.
  const auto rank = getNumDimensions();
  auto* context = getContext();
  return mlir::AffineMap::getMultiDimIdentityMap(rank, context);
}
void MosaicGPUDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST

View File

@ -20,6 +20,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"
@ -39,8 +40,8 @@ class MosaicGPU_Type<string name, string mnemonic_, list<Trait> traits = []>
let mnemonic = mnemonic_;
}
class MosaicGPU_Attr<string name, string mnemonic_>
: AttrDef<MosaicGPU_Dialect, name> {
class MosaicGPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
: AttrDef<MosaicGPU_Dialect, name, traits> {
let mnemonic = mnemonic_;
}
@ -194,7 +195,7 @@ def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode",
def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
let parameters = (ins Variadic<I64>:$tiling);
let summary = "Tiles a suffix of memref dimensions.";
let summary = "Specifies a transform that tiles suffix dimensions of a memref in SMEM.";
let description = [{
For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
the shape of the result will be (5, 2, 4, 64, 32). The shape always ends
@ -210,10 +211,38 @@ def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> {
def TransposeTransformAttr : MosaicGPU_Attr<"TransposeTransform", "transpose"> {
let parameters = (ins Variadic<I64>:$permutation);
let summary = "Specifies how to transpose a memref.";
let summary = "Specifies a transpose transform of a memref in SMEM.";
let assemblyFormat = "`<` $permutation `>`";
}
// Transform attribute carrying a single SwizzlingModeAttr parameter.
// Assembly form: the `swizzle` mnemonic followed by `<` $swizzle `>`.
def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> {
let parameters = (ins "SwizzlingModeAttr":$swizzle);
let summary = "Specifies a swizzle transform of a memref in SMEM.";
let assemblyFormat = "`<` $swizzle `>`";
}
// Layout attribute that declares the MemRefLayoutAttrInterface methods, so the
// dialect must provide their C++ definitions (e.g. getAffineMap).
// Parameters: a dimension count and an ordered list of transform attributes.
// NOTE(review): nothing in this definition checks that `num_dimensions`
// matches the memref rank, or that `transforms` only holds the attribute kinds
// listed in the description - confirm whether that is verified elsewhere.
def LayoutAttr : MosaicGPU_Attr<"Layout", "layout",
[DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
let parameters = (ins
TypeParameter<"int32_t", "number of dimensions">:$num_dimensions,
ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms
);
let summary = "Specifies a layout of a memref in SMEM.";
let description = [{
This layout attribute is used to specify the layout of a memref in SMEM.
It is composed of a number of transforms, which are applied in the order
they are provided. The transforms can be any combination of:
- TileTransformAttr
- TransposeTransformAttr
- SwizzleTransformAttr
The num_dimensions parameter must match the rank of the memref shape.
}];
// Assembly form: `<` num_dimensions `,` comma-separated transforms `>`.
let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`";
}
def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">;
def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
@ -235,14 +264,16 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
indicates that the slice length is 1 and that the corresponding dimension
should be collapsed and does not appear in the `destination` MemRef.
Additional `transforms` may be provided to control how the `source` data is
mapped to the `destination`. The transformations will be composed in the
order they are provided. The `swizzle` attribute controls what swizzling
is applied to the data after it is transformed, before it is finally written
to SMEM. The transformed data is written in row-major order to the
contiguous SMEM `destination`. The untransformed `source` data does not need
to be contiguous, except for the last dimension, which needs to be
contiguous and the minor-most dimension.
Setting the `layout` attribute of the `destination` MemRef to an instance of
`LayoutAttr` enables additional `transforms` that control how the `source`
data is mapped to the `destination`. The transformations will be composed in
the order they are provided. The transforms may contain up to a single
swizzle transform that controls what swizzling is applied to the data after
it is transformed, before it is finally written to SMEM. The transformed
data is written in row-major order to the contiguous SMEM `destination`.
The untransformed `source` data does not need to be contiguous, except for
the last dimension, which needs to be contiguous and the minor-most
dimension.
The `collective` attribute can be provided to use TMA multicast to more
efficiently load the GMEM data in cases where multiple thread blocks are
@ -266,8 +297,6 @@ def MosaicGPU_AsyncLoadOp : Op<MosaicGPU_Dialect, "async_load",
// Attributes
DenseI64ArrayAttr:$slice_lengths,
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
DefaultValuedAttr<MosaicGPU_SwizzlingMode, "SwizzlingMode::kNoSwizzle">:$swizzle,
TypedArrayAttrBase<MosaicGPU_Dimension, "dimensions">:$collective
);
@ -299,11 +328,13 @@ def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
indicates that this dimension is collapsed in the `source` and needs to be
expanded to a slice of size 1 in the `destination`.
Additional `transforms` may be provided to control how the `destination`
data in GMEM is mapped to the `source` data in SMEM. The transformations
will be composed in the order they are provided. The `swizzle` attribute
is the swizzling mode of the `source` data in SMEM. The `source` SMEM data
is contiguous and the transformed data is written to the `destination` GMEM
Setting the `layout` attribute of the `source` MemRef to an instance of
`LayoutAttr` enables additional `transforms` that control how the
`destination` data in GMEM is mapped to the `source` data in SMEM. The
transformations will be composed in the order they are provided. The
transforms may contain up to a single swizzle transform that is the
swizzling mode of the `source` data in SMEM. The `source` SMEM data is
contiguous and the transformed data is written to the `destination` GMEM
which does not need to be contiguous.
The `predicate` input should be set to `true` by a single thread in the
@ -318,9 +349,7 @@ def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
PtxPredicate:$predicate,
// Attributes
DenseI64ArrayAttr:$slice_lengths,
TypedArrayAttrBase<AnyAttrOf<[TileTransformAttr, TransposeTransformAttr]>, "transforms">:$transforms,
DefaultValuedAttr<MosaicGPU_SwizzlingMode, "SwizzlingMode::kNoSwizzle">:$swizzle
DenseI64ArrayAttr:$slice_lengths
);
let assemblyFormat = [{

View File

@ -163,7 +163,6 @@ class DialectTest(MosaicGpuTest):
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
@ -190,7 +189,6 @@ class DialectTest(MosaicGpuTest):
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
@ -217,7 +215,6 @@ class DialectTest(MosaicGpuTest):
barrier,
indices,
slice_lengths=[-2, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
@ -245,7 +242,6 @@ class DialectTest(MosaicGpuTest):
barrier,
indices,
slice_lengths=[-1, 4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
@ -271,7 +267,6 @@ class DialectTest(MosaicGpuTest):
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
@ -297,7 +292,6 @@ class DialectTest(MosaicGpuTest):
barrier,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
)
)
@ -324,7 +318,6 @@ class DialectTest(MosaicGpuTest):
barrier,
indices,
slice_lengths=[4],
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([
ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x),
ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x),
@ -356,7 +349,6 @@ class DialectTest(MosaicGpuTest):
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
@ -380,7 +372,6 @@ class DialectTest(MosaicGpuTest):
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
@ -404,7 +395,6 @@ class DialectTest(MosaicGpuTest):
destination,
indices,
slice_lengths=[-2, 8],
transforms=ir.ArrayAttr.get([]),
)
)
@ -429,7 +419,6 @@ class DialectTest(MosaicGpuTest):
destination,
indices,
slice_lengths=[-1, 4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
@ -452,7 +441,6 @@ class DialectTest(MosaicGpuTest):
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)
@ -475,7 +463,6 @@ class DialectTest(MosaicGpuTest):
destination,
indices,
slice_lengths=[4, 8],
transforms=ir.ArrayAttr.get([]),
)
)

View File

@ -2254,6 +2254,35 @@ class LayoutTest(TestCase):
np.testing.assert_array_equal(f(x), x)
def with_swizzle(
    mem_ref: ir.Value, swizzle: mgpu_dialect.SwizzlingMode
) -> ir.Value:
  """Casts `mem_ref` to a memref type whose layout ends in a swizzle transform.

  When the memref already carries a LayoutAttr, the swizzle transform is
  appended to its existing transforms. Any other layout is replaced by a
  LayoutAttr that holds only the swizzle transform.

  Args:
    mem_ref: The memref value to re-type.
    swizzle: The swizzling mode to record in the layout.

  Returns:
    The result of a `memref.cast` of `mem_ref` to the new memref type, which
    keeps the original shape, element type and memory space.
  """
  ref_ty = ir.MemRefType(mem_ref.type)
  rank = ref_ty.rank

  # Start from the existing Mosaic GPU layout if there is one, otherwise from
  # an empty layout of matching rank.
  if mgpu_dialect.LayoutAttr.isinstance(ref_ty.layout):
    base_layout = mgpu_dialect.LayoutAttr(ref_ty.layout)
  else:
    base_layout = mgpu_dialect.LayoutAttr.get(rank, [])

  transforms = [
      *base_layout.transforms,
      mgpu_dialect.SwizzleTransformAttr.get(swizzle),
  ]
  swizzled_layout = mgpu_dialect.LayoutAttr.get(rank, transforms)

  swizzled_ty = ir.MemRefType.get(
      ref_ty.shape,
      ref_ty.element_type,
      swizzled_layout,
      ref_ty.memory_space,
  )
  return memref.cast(swizzled_ty, mem_ref)
class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
"""Device tests with lowering from the MLIR dialect and layout inference."""
@ -2328,23 +2357,19 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
destination=with_swizzle(a_smem_ref, swizzle),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
swizzle=swizzle,
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
destination=with_swizzle(b_smem_ref, swizzle),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
swizzle=swizzle,
)
parities = memref.load(tma_barrier.phases, [])
@ -2366,12 +2391,10 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# SMEM -> GMEM
mgpu_dialect.async_store(
source=result_smem_ref,
source=with_swizzle(result_smem_ref, swizzle),
destination=result_gmem_ref,
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
swizzle=swizzle,
)
nvvm.cp_async_bulk_wait_group(0)
utils.warpgroup_barrier()
@ -2437,23 +2460,19 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
destination=with_swizzle(a_smem_ref, swizzle),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32, zero_i32, zero_i32],
slice_lengths=shape_a,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
swizzle=swizzle,
)
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
destination=with_swizzle(b_smem_ref, swizzle),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32, zero_i32, zero_i32],
slice_lengths=shape_b,
transforms=ir.ArrayAttr.get([]),
collective=ir.ArrayAttr.get([]),
swizzle=swizzle,
)
parities = memref.load(tma_barrier.phases, [])
@ -2486,8 +2505,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
destination=result_gmem_ref,
indices=[zero_i32, zero_i32],
slice_lengths=shape_result,
transforms=ir.ArrayAttr.get([]),
swizzle=mgpu_dialect.SwizzlingMode.kNoSwizzle,
)
nvvm.cp_async_bulk_wait_group(0)