diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py index 9dbbcad3a..a5b251867 100644 --- a/jax/experimental/mosaic/gpu/dialect_lowering.py +++ b/jax/experimental/mosaic/gpu/dialect_lowering.py @@ -240,12 +240,33 @@ def _vector_splat_op_lowering_rule( return [_fragmented_array_to_ir(fragmented_array, out_vec_ty)] +def layout_to_swizzle(layout: ir.Attribute) -> mgpu.SwizzlingMode: + """Returns the swizzle mode for the given layout. + + If the layout is not a LayoutAttr, the swizzle is kNoSwizzle. Otherwise, + the layout must consist of exactly one swizzle transform. + """ + if mgpu.LayoutAttr.isinstance(layout): + transforms = mgpu.LayoutAttr(layout).transforms + if len(transforms) != 1: + raise ValueError(f"{layout} has multiple transforms") + if not mgpu.SwizzleTransformAttr.isinstance(transforms[0]): + raise NotImplementedError("Only siwzzle transforms are supported.") + # TODO(dasenov): Swizzling can change if the ref is sliced in certain + # ways. We might want to enforce some restrictions here. + return mgpu.SwizzleTransformAttr(transforms[0]).swizzle + + return mgpu.SwizzlingMode.kNoSwizzle + + @_register_lowering(mgpu.AsyncLoadOp) def _mgpu_async_load_op_lowering_rule( ctx: LoweringContext, load_op: mgpu.AsyncLoadOp ) -> Sequence[ir.Value]: assert ctx.launch_context is not None barrier = utils.BarrierRef.from_dialect_barrier_memref(load_op.barrier) + + dst_layout = ir.MemRefType(load_op.destination.type).layout # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=load_op.source, @@ -253,7 +274,7 @@ def _mgpu_async_load_op_lowering_rule( barrier=barrier, arrive=False, uniform=True, - swizzle=load_op.swizzle.value, + swizzle=layout_to_swizzle(dst_layout), predicate=ctx.single_thread_per_warpgroup_predicate, ) return [] @@ -264,11 +285,13 @@ def _mgpu_async_store_op_lowering_rule( ctx: LoweringContext, store_op: mgpu.AsyncStoreOp ) -> Sequence[ir.Value]: assert ctx.launch_context is not None + + src_layout = ir.MemRefType(store_op.source.type).layout # TODO(dasenov): Add support for the remaining op properties. ctx.launch_context.async_copy( src_ref=store_op.source, dst_ref=store_op.destination, - swizzle=store_op.swizzle.value, + swizzle=layout_to_swizzle(src_layout), uniform=True, predicate=ctx.single_thread_per_warpgroup_predicate, ) diff --git a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc index d3009be21..d0ce5e78f 100644 --- a/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/mosaic_gpu_ext.cc @@ -13,9 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "mlir-c/IR.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" // IWYU pragma: keep #include "nanobind/nanobind.h" +#include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" #include "jaxlib/mosaic/dialect/gpu/integrations/c/gpu_dialect.h" namespace nb = nanobind; @@ -31,4 +35,39 @@ NB_MODULE(_mosaic_gpu_ext, m) { } }, nb::arg("context"), nb::arg("load") = true); + + mlir::python::nanobind_adaptors::mlir_attribute_subclass( + m, "SwizzleTransformAttr", MosaicGpuIsASwizzleTransformAttr) + .def_classmethod( + "get", + [](nb::object cls, int32_t swizzle, MlirContext ctx) { + return cls(MosaicGpuSwizzleTransformAttrGet( + ctx, static_cast(swizzle))); + }, + nb::arg("cls"), nb::arg("swizzle"), + nb::arg("context").none() = nb::none(), + "Creates a SwizzleTransformAttr with the given swizzle.") + .def_property_readonly("swizzle", [](MlirAttribute self) { + return MosaicGpuSwizzleTransformAttrGetSwizzle(self); + }); + + mlir::python::nanobind_adaptors::mlir_attribute_subclass( + m, "LayoutAttr", MosaicGpuIsALayoutAttr) + .def_classmethod( + "get", + [](nb::object cls, int32_t num_dimensions, + std::vector& transforms, MlirContext ctx) { + return cls(MosaicGpuLayoutAttrGet( + ctx, num_dimensions, transforms.data(), transforms.size())); + }, + nb::arg("cls"), nb::arg("num_dimensions"), nb::arg("transforms"), + nb::arg("context").none() = nb::none(), + "Creates a LayoutAttr with the given transforms.") + .def_property_readonly("transforms", [](MlirAttribute self) { + std::vector result; + for (int i = 0; i < MosaicGpuLayoutAttrGetTransformsSize(self); ++i) { + result.push_back(MosaicGpuLayoutAttrGetTransform(self, i)); + } + return result; + }); } diff --git a/jaxlib/mosaic/dialect/gpu/BUILD b/jaxlib/mosaic/dialect/gpu/BUILD index 50ea58104..e21c8756a 100644 --- a/jaxlib/mosaic/dialect/gpu/BUILD +++ b/jaxlib/mosaic/dialect/gpu/BUILD @@ -198,10 +198,12 @@ genrule( ) DIALECT_CAPI_SOURCES = [ + ":integrations/c/attributes.cc", ":integrations/c/gpu_dialect.cc", ] DIALECT_CAPI_HEADERS = [ + ":integrations/c/attributes.h", ":integrations/c/gpu_dialect.h", ] @@ -212,7 +214,10 @@ cc_library( deps = [ ":mosaic_gpu", ":mosaic_gpu_inc_gen", + "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -234,7 +239,10 @@ cc_library( deps = [ ":mosaic_gpu", ":mosaic_gpu_inc_gen", + "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], alwayslink = True, ) diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc new file mode 100644 index 000000000..27152bead --- /dev/null +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.cc @@ -0,0 +1,62 @@ +#include "jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h" + +#include +#include + +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/IR/Attributes.h" +#include "mlir/Support/LLVM.h" +#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" + +//===----------------------------------------------------------------------===// +// SwizzleTransformAttr +//===----------------------------------------------------------------------===// +bool MosaicGpuIsASwizzleTransformAttr(MlirAttribute attr) { + return mlir::isa(unwrap(attr)); +} +MlirAttribute MosaicGpuSwizzleTransformAttrGet(MlirContext ctx, + int32_t swizzle) { + return wrap(mosaic_gpu::SwizzleTransformAttr::get( + unwrap(ctx), + mosaic_gpu::SwizzlingModeAttr::get( + unwrap(ctx), static_cast(swizzle)))); +} +int32_t MosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr) { + return static_cast( + mlir::cast(unwrap(attr)) + .getSwizzle() + .getValue()); +} + +//===----------------------------------------------------------------------===// +// LayoutAttr +//===----------------------------------------------------------------------===// + +bool MosaicGpuIsALayoutAttr(MlirAttribute attr) { + return mlir::isa(unwrap(attr)); +} + +MlirAttribute MosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions, + MlirAttribute* transforms, + int32_t transforms_size) { + std::vector unwrapped_transforms; + unwrapped_transforms.reserve(transforms_size); + for (int i = 0; i < transforms_size; ++i) { + unwrapped_transforms.push_back(unwrap(transforms[i])); + } + return wrap(mosaic_gpu::LayoutAttr::get(unwrap(ctx), num_dimensions, + unwrapped_transforms)); +} + +int32_t MosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr) { + return mlir::cast(unwrap(attr)) + .getTransforms() + .size(); +} + +MlirAttribute MosaicGpuLayoutAttrGetTransform(MlirAttribute attr, + int32_t index) { + return wrap( + mlir::cast(unwrap(attr)).getTransforms()[index]); +} \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h new file mode 100644 index 000000000..149f4c66c --- /dev/null +++ b/jaxlib/mosaic/dialect/gpu/integrations/c/attributes.h @@ -0,0 +1,60 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_ +#define JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_ + +#include + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// SwizzleTransformAttr +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool MosaicGpuIsASwizzleTransformAttr(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +MosaicGpuSwizzleTransformAttrGet(MlirContext ctx, int32_t swizzle); + +MLIR_CAPI_EXPORTED int32_t +MosaicGpuSwizzleTransformAttrGetSwizzle(MlirAttribute attr); + +//===----------------------------------------------------------------------===// +// LayoutAttr +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool MosaicGpuIsALayoutAttr(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +MosaicGpuLayoutAttrGet(MlirContext ctx, int32_t num_dimensions, + MlirAttribute* transforms, int32_t transforms_size); + +MLIR_CAPI_EXPORTED int32_t +MosaicGpuLayoutAttrGetTransformsSize(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +MosaicGpuLayoutAttrGetTransform(MlirAttribute attr, int32_t index); + +#ifdef __cplusplus +} +#endif + +#endif // JAXLIB_MOSAIC_DIALECT_GPU_INTEGRATIONS_C_ATTRIBUTES_H_ \ No newline at end of file diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc index fb3da9c1c..e1f61a1fe 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc @@ -485,6 +485,14 @@ llvm::LogicalResult WGMMAOp::verify() { return llvm::success(); } +mlir::AffineMap LayoutAttr::getAffineMap() const { + // This always returns an identity map. It's technically not correct, but we + // don't actually use it anywhere. It's only called during verification of the + // layout attribute and needs to be semi-valid. + return mlir::AffineMap::getMultiDimIdentityMap(getNumDimensions(), + getContext()); +} + void MosaicGPUDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST diff --git a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td index eaca0a723..8b9b5454e 100644 --- a/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td +++ b/jaxlib/mosaic/dialect/gpu/mosaic_gpu.td @@ -20,6 +20,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/CommonTypeConstraints.td" @@ -39,8 +40,8 @@ class MosaicGPU_Type traits = []> let mnemonic = mnemonic_; } -class MosaicGPU_Attr - : AttrDef { +class MosaicGPU_Attr traits = []> + : AttrDef { let mnemonic = mnemonic_; } @@ -194,7 +195,7 @@ def MosaicGPU_SwizzlingMode : I32EnumAttr<"SwizzlingMode", def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> { let parameters = (ins Variadic:$tiling); - let summary = "Tiles a suffix of memref dimensions."; + let summary = "Specifies a transform that tiles suffix dimensions of a memref in SMEM."; let description = [{ For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), the shape of the result will be (5, 2, 4, 64, 32). The shape always ends @@ -210,10 +211,38 @@ def TileTransformAttr : MosaicGPU_Attr<"TileTransform", "tile"> { def TransposeTransformAttr : MosaicGPU_Attr<"TransposeTransform", "transpose"> { let parameters = (ins Variadic:$permutation); - let summary = "Specifies how to transpose a memref."; + let summary = "Specifies a transpose transform of a memref in SMEM."; let assemblyFormat = "`<` $permutation `>`"; } +def SwizzleTransformAttr : MosaicGPU_Attr<"SwizzleTransform", "swizzle"> { + let parameters = (ins "SwizzlingModeAttr":$swizzle); + + let summary = "Specifies a swizzle transform of a memref in SMEM."; + let assemblyFormat = "`<` $swizzle `>`"; +} + +def LayoutAttr : MosaicGPU_Attr<"Layout", "layout", + [DeclareAttrInterfaceMethods]> { + let parameters = (ins + TypeParameter<"int32_t", "number of dimensions">:$num_dimensions, + ArrayRefParameter<"mlir::Attribute", "transforms">:$transforms + ); + + let summary = "Specifies a layout of a memref in SMEM."; + let description = [{ + This layout attribute is used to specify the layout of a memref in SMEM. + It is composed of a number of transforms, which are applied in the order + they are provided. The transforms can be any combination of: + - TileTransformAttr + - TransposeTransformAttr + - SwizzleTransformAttr + + The num_dimensions parameter must match the rank of the memref shape. + }]; + let assemblyFormat = "`<` $num_dimensions `,` $transforms `>`"; +} + def GlobalMemory : Resource<"::mosaic_gpu::GlobalMemory">; def MosaicGPU_AsyncLoadOp : Op, "transforms">:$transforms, - DefaultValuedAttr:$swizzle, TypedArrayAttrBase:$collective ); @@ -299,11 +328,13 @@ def MosaicGPU_AsyncStoreOp : Op, "transforms">:$transforms, - DefaultValuedAttr:$swizzle + DenseI64ArrayAttr:$slice_lengths ); let assemblyFormat = [{ diff --git a/tests/mosaic/gpu_dialect_test.py b/tests/mosaic/gpu_dialect_test.py index d4b854f84..5cfdd5a34 100644 --- a/tests/mosaic/gpu_dialect_test.py +++ b/tests/mosaic/gpu_dialect_test.py @@ -163,7 +163,6 @@ class DialectTest(MosaicGpuTest): barrier, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), ) ) @@ -190,7 +189,6 @@ class DialectTest(MosaicGpuTest): barrier, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), ) ) @@ -217,7 +215,6 @@ class DialectTest(MosaicGpuTest): barrier, indices, slice_lengths=[-2, 8], - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), ) ) @@ -245,7 +242,6 @@ class DialectTest(MosaicGpuTest): barrier, indices, slice_lengths=[-1, 4, 8], - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), ) ) @@ -271,7 +267,6 @@ class DialectTest(MosaicGpuTest): barrier, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), ) ) @@ -297,7 +292,6 @@ class DialectTest(MosaicGpuTest): barrier, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), ) ) @@ -324,7 +318,6 @@ class DialectTest(MosaicGpuTest): barrier, indices, slice_lengths=[4], - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([ ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x), ir.IntegerAttr.get(i32, mgpu.dialect.Dimension.x), @@ -356,7 +349,6 @@ class DialectTest(MosaicGpuTest): destination, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), ) ) @@ -380,7 +372,6 @@ class DialectTest(MosaicGpuTest): destination, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), ) ) @@ -404,7 +395,6 @@ class DialectTest(MosaicGpuTest): destination, indices, slice_lengths=[-2, 8], - transforms=ir.ArrayAttr.get([]), ) ) @@ -429,7 +419,6 @@ class DialectTest(MosaicGpuTest): destination, indices, slice_lengths=[-1, 4, 8], - transforms=ir.ArrayAttr.get([]), ) ) @@ -452,7 +441,6 @@ class DialectTest(MosaicGpuTest): destination, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), ) ) @@ -475,7 +463,6 @@ class DialectTest(MosaicGpuTest): destination, indices, slice_lengths=[4, 8], - transforms=ir.ArrayAttr.get([]), ) ) diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index 1ab033a36..09ea1717f 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -2254,6 +2254,35 @@ class LayoutTest(TestCase): np.testing.assert_array_equal(f(x), x) +def with_swizzle( + mem_ref: ir.Value, swizzle: mgpu_dialect.SwizzlingMode +) -> ir.Value: + """Appends a swizzle transform to the layout of the given memref. + + If the memref's layout is not a LayoutAttr, it is replaced with a LayoutAttr + with a single swizzle transform. + """ + mem_ref_type = ir.MemRefType(mem_ref.type) + old_layout = ( + mgpu_dialect.LayoutAttr(mem_ref_type.layout) + if mgpu_dialect.LayoutAttr.isinstance(mem_ref_type.layout) + else mgpu_dialect.LayoutAttr.get(mem_ref_type.rank, []) + ) + + swizzle_transform = mgpu_dialect.SwizzleTransformAttr.get(swizzle) + new_layout = mgpu_dialect.LayoutAttr.get( + mem_ref_type.rank, old_layout.transforms + [swizzle_transform] + ) + + memref_swizzle_type = ir.MemRefType.get( + mem_ref_type.shape, + mem_ref_type.element_type, + new_layout, + mem_ref_type.memory_space, + ) + return memref.cast(memref_swizzle_type, mem_ref) + + class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): """Device tests with lowering from the MLIR dialect and layout inference.""" @@ -2328,23 +2357,19 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): # GMEM -> SMEM mgpu_dialect.async_load( source=a_gmem_ref, - destination=a_smem_ref, + destination=with_swizzle(a_smem_ref, swizzle), barrier=dialect_barrier, indices=[zero_i32, zero_i32], slice_lengths=shape, - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), - swizzle=swizzle, ) mgpu_dialect.async_load( source=b_gmem_ref, - destination=b_smem_ref, + destination=with_swizzle(b_smem_ref, swizzle), barrier=dialect_barrier, indices=[zero_i32, zero_i32], slice_lengths=shape, - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), - swizzle=swizzle, ) parities = memref.load(tma_barrier.phases, []) @@ -2366,12 +2391,10 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): # SMEM -> GMEM mgpu_dialect.async_store( - source=result_smem_ref, + source=with_swizzle(result_smem_ref, swizzle), destination=result_gmem_ref, indices=[zero_i32, zero_i32], slice_lengths=shape, - transforms=ir.ArrayAttr.get([]), - swizzle=swizzle, ) nvvm.cp_async_bulk_wait_group(0) utils.warpgroup_barrier() @@ -2437,23 +2460,19 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): # GMEM -> SMEM mgpu_dialect.async_load( source=a_gmem_ref, - destination=a_smem_ref, + destination=with_swizzle(a_smem_ref, swizzle), barrier=dialect_barrier, indices=[zero_i32, zero_i32, zero_i32, zero_i32], slice_lengths=shape_a, - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), - swizzle=swizzle, ) mgpu_dialect.async_load( source=b_gmem_ref, - destination=b_smem_ref, + destination=with_swizzle(b_smem_ref, swizzle), barrier=dialect_barrier, indices=[zero_i32, zero_i32, zero_i32, zero_i32], slice_lengths=shape_b, - transforms=ir.ArrayAttr.get([]), collective=ir.ArrayAttr.get([]), - swizzle=swizzle, ) parities = memref.load(tma_barrier.phases, []) @@ -2486,8 +2505,6 @@ class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase): destination=result_gmem_ref, indices=[zero_i32, zero_i32], slice_lengths=shape_result, - transforms=ir.ArrayAttr.get([]), - swizzle=mgpu_dialect.SwizzlingMode.kNoSwizzle, ) nvvm.cp_async_bulk_wait_group(0)