diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py
index ded17d5d4..8b8fdaceb 100644
--- a/jax/experimental/mosaic/gpu/fragmented_array.py
+++ b/jax/experimental/mosaic/gpu/fragmented_array.py
@@ -1373,15 +1373,30 @@ class FragmentedArray:
       for group_size in (8, 4, 2):
         int_ty = ir.IntegerType.get_signless(group_size * 4)
         while vector_len - offset >= group_size:
-          reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
-          reg_slice_int = utils.bitcast(reg_slice, int_ty)
-          if int_ty != i32:
-            reg_slice_int = arith.extsi(i32, reg_slice_int)
-          reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
-          out_int_regs.extend(
-              upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
-              for part in range(group_size // 2)
-          )
+          # If the vector originates from a slice (common after relayouts), we
+          # can fuse the slicing into the conversion and prevent LLVM from
+          # generating a bunch of shifts to align the vector data to the LSB.
+          # This also lets us share the right shift among more vectors.
+          if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
+              and utils.bitwidth(slice_op.vector.type) == 32
+              and slice_op.strides[0].value == 1):
+            slice_offset = slice_op.offsets[0].value + offset
+            reg_int = utils.bitcast(slice_op.vector, i32)
+            reg_int_shr = arith.shrui(reg_int, c(4, i32))
+            out_int_regs.extend(
+                upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
+                for part in range(group_size // 2)
+            )
+          else:
+            reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
+            reg_slice_int = utils.bitcast(reg_slice, int_ty)
+            if int_ty != i32:
+              reg_slice_int = arith.extsi(i32, reg_slice_int)
+            reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
+            out_int_regs.extend(
+                upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
+                for part in range(group_size // 2)
+            )
           offset += group_size
       assert offset == vector_len
       out_vec_int = utils.vector_concat([
diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py
index 91cb19746..28534cf40 100644
--- a/jax/experimental/mosaic/gpu/utils.py
+++ b/jax/experimental/mosaic/gpu/utils.py
@@ -346,7 +346,7 @@ def bitwidth_impl(ty: ir.Type):
     return ir.IntegerType(ty).width
   if ir.FloatType.isinstance(ty):
     return ir.FloatType(ty).width
-  if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
+  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
     return MBARRIER_BYTES * 8
   if ir.VectorType.isinstance(ty):
     vty = ir.VectorType(ty)
@@ -1237,17 +1237,15 @@ def ceil_div(x: int, y: int):
 
 
 def vector_slice(v: ir.Value, s: slice):
-  i32 = ir.IntegerType.get_signless(32)
   v_ty = ir.VectorType(v.type)
   if len(v_ty.shape) != 1:
-    raise NotImplementedError
+    raise NotImplementedError(v_ty)
   [v_len] = v_ty.shape
-  it = range(v_len)[s]
-  result = llvm.mlir_undef(ir.VectorType.get((len(it),), v_ty.element_type))
-  for tgt, src in enumerate(it):
-    elem = llvm.extractelement(v, c(src, i32))
-    result = llvm.insertelement(result, elem, c(tgt, i32))
-  return result
+  slice_length = len(range(v_len)[s])
+  return vector.extract_strided_slice(
+      ir.VectorType.get((slice_length,), v_ty.element_type),
+      v, [s.start or 0], [slice_length], [1],
+  )
 
 
 def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index 0686db098..9249ae256 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -65,6 +65,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:VectorDialect", ], ) diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index cee34ddae..b8c3fbb74 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "jaxlib/mosaic/gpu/passes.h" +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/BuiltinOps.h" #include "mlir/include/mlir/IR/SymbolTable.h" @@ -36,6 +38,49 @@ namespace gpu { namespace { +// Upstream MLIR does not implement an LLVM lowering pattern for this op. +struct ConvertExtractStridedSlicePattern final + : public mlir::OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + mlir::LogicalResult matchAndRewrite( + mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst, + mlir::ConversionPatternRewriter &rewriter) const override { + auto vty = op.getSourceVectorType(); + if (vty.getRank() != 1) { + return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported"); + } + int64_t size = + (*op.getSizes().getAsRange().begin()).getSInt(); + if (size < 0) { + return rewriter.notifyMatchFailure(op, "size is negative"); + } + int64_t start = + (*op.getOffsets().getAsRange().begin()).getSInt(); + int64_t stride = + (*op.getStrides().getAsRange().begin()).getSInt(); + if (stride != 1) { + return rewriter.notifyMatchFailure(op, "only stride 1 is supported"); + } + if (start < 0 || start + size > vty.getShape()[0]) { + return rewriter.notifyMatchFailure(op, "slice is out of bounds"); + } + mlir::Value result = rewriter.create( + op.getLoc(), op.getResult().getType()); + for (int64_t i = 0; i < size; ++i) { + result = rewriter.create( + op.getLoc(), result, + rewriter.create( + op.getLoc(), subst.getVector(), + rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(i + start))), + rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(i))); + } + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; + class ConvertGpuToLLVMPass : public jaxlib::mlir::Pass { public: @@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass }); auto symtab = mlir::SymbolTable(getOperation()); mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false); + patterns.insert(&getContext()); if (mlir::applyPartialConversion(getOperation(), target, std::move(patterns)) .failed()) {