[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts

This allows us to significantly simplify the generated PTX/SASS, which is
currently cluttered by the shifts LLVM inserts to align each slice to start
at bit 0, and by the right shifts it then fails to CSE.
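For intuition, here is a rough pure-Python model of the trick (illustrative only, not the actual Mosaic GPU code): a 32-bit register packs eight s4 values, and each `part` names one pair of nibbles, read from the register and from a single shared 4-bit right shift of it. Slicing the packed vector then only renumbers `part`; no per-slice alignment shifts are needed.

```python
# Rough pure-Python model; `s4` and `nibble_pair` are made-up helpers.
def s4(x: int) -> int:
  """Sign-extends the low 4 bits of x."""
  x &= 0xF
  return x - 16 if x >= 8 else x

def nibble_pair(reg: int, reg_shr: int, part: int) -> tuple[int, int]:
  """Returns elements 2*part and 2*part + 1 of the packed s4 vector."""
  return s4(reg >> (8 * part)), s4(reg_shr >> (8 * part))

reg = 0x87654321      # packs the s4 values 1, 2, 3, 4, 5, 6, 7, -8
reg_shr = reg >> 4    # one right shift, shared by every pair
assert nibble_pair(reg, reg_shr, part=1) == (3, 4)
assert nibble_pair(reg, reg_shr, part=3) == (7, -8)
# A slice starting at element 2 reuses reg/reg_shr unchanged and just bumps
# `part` by slice_offset // 2 == 1, mirroring the fused path in the diff.
assert nibble_pair(reg, reg_shr, part=0 + 2 // 2) == (3, 4)
```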

PiperOrigin-RevId: 737967890
commit 8da93249d2 (parent 7a459f0ed1)
Author: Adam Paszke, 2025-03-18 05:37:52 -07:00 (committed by jax authors)
4 changed files with 78 additions and 18 deletions

jax/experimental/mosaic/gpu/fragmented_array.py

@@ -1373,15 +1373,30 @@ class FragmentedArray:
     for group_size in (8, 4, 2):
       int_ty = ir.IntegerType.get_signless(group_size * 4)
       while vector_len - offset >= group_size:
-        reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
-        reg_slice_int = utils.bitcast(reg_slice, int_ty)
-        if int_ty != i32:
-          reg_slice_int = arith.extsi(i32, reg_slice_int)
-        reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
-        out_int_regs.extend(
-            upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
-            for part in range(group_size // 2)
-        )
+        # If the vector originates from a slice (common after relayouts), we
+        # can fuse the slicing into the conversion and prevent LLVM from
+        # generating a bunch of shifts to align the vector data to the LSB.
+        # This also lets us share the right shift among more vectors.
+        if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
+            and utils.bitwidth(slice_op.vector.type) == 32
+            and slice_op.strides[0].value == 1):
+          slice_offset = slice_op.offsets[0].value + offset
+          reg_int = utils.bitcast(slice_op.vector, i32)
+          reg_int_shr = arith.shrui(reg_int, c(4, i32))
+          out_int_regs.extend(
+              upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
+              for part in range(group_size // 2)
+          )
+        else:
+          reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
+          reg_slice_int = utils.bitcast(reg_slice, int_ty)
+          if int_ty != i32:
+            reg_slice_int = arith.extsi(i32, reg_slice_int)
+          reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
+          out_int_regs.extend(
+              upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
+              for part in range(group_size // 2)
+          )
         offset += group_size
     assert offset == vector_len
     out_vec_int = utils.vector_concat([
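The fused branch keys off the value's defining operation, using the MLIR Python bindings idiom `value.owner.opview` seen above. A minimal standalone sketch of that idiom (the helper name is hypothetical, and an active MLIR context with a loaded vector dialect is assumed):

```python
from mlir import ir
from mlir.dialects import vector

def strided_slice_source(v: ir.Value) -> ir.Value | None:
  """Returns the unsliced vector if `v` comes from a unit-stride 1-D slice."""
  # Assumes `v` is an op result (as in the diff above); for an OpResult,
  # .owner is the defining Operation and .opview is its typed view.
  op = v.owner.opview
  if (isinstance(op, vector.ExtractStridedSliceOp)
      and op.strides[0].value == 1):
    return op.vector  # the operand that was sliced
  return None
```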

jax/experimental/mosaic/gpu/utils.py

@@ -346,7 +346,7 @@ def bitwidth_impl(ty: ir.Type):
     return ir.IntegerType(ty).width
   if ir.FloatType.isinstance(ty):
     return ir.FloatType(ty).width
-  if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
+  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
     return MBARRIER_BYTES * 8
   if ir.VectorType.isinstance(ty):
     vty = ir.VectorType(ty)
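The one-line fix above matters more than it looks: `ir.Type.parse(...)` returns a Type object, which is always truthy, so the old condition fired for any `ty` once the dialect was loaded; the new condition actually compares against `ty`. The bug class, in plain Python terms:

```python
# A freshly constructed object is truthy, so `cond and SomeObject()`
# behaves like `cond` rather than like a type comparison.
class Ty:
  pass

barrier_ty = Ty()    # stands in for ir.Type.parse("!mosaic_gpu.barrier")
some_other_ty = Ty()
assert bool(barrier_ty)                   # old check: passes for any type
assert not (some_other_ty == barrier_ty)  # fixed check: actually compares
```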
@@ -1237,17 +1237,15 @@ def ceil_div(x: int, y: int):
 
 
 def vector_slice(v: ir.Value, s: slice):
-  i32 = ir.IntegerType.get_signless(32)
   v_ty = ir.VectorType(v.type)
   if len(v_ty.shape) != 1:
-    raise NotImplementedError
+    raise NotImplementedError(v_ty)
   [v_len] = v_ty.shape
-  it = range(v_len)[s]
-  result = llvm.mlir_undef(ir.VectorType.get((len(it),), v_ty.element_type))
-  for tgt, src in enumerate(it):
-    elem = llvm.extractelement(v, c(src, i32))
-    result = llvm.insertelement(result, elem, c(tgt, i32))
-  return result
+  slice_length = len(range(v_len)[s])
+  return vector.extract_strided_slice(
+      ir.VectorType.get((slice_length,), v_ty.element_type),
+      v, [s.start or 0], [slice_length], [1],
+  )
 
 
 def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
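The rewritten helper derives the result length with `len(range(v_len)[s])`, which inherits Python's slice semantics (clamping, open endpoints). A quick pure-Python check of that behavior:

```python
# Pure-Python check of the slice-length computation used above.
v_len = 8
assert len(range(v_len)[slice(0, 4)]) == 4
assert len(range(v_len)[slice(4, None)]) == 4   # open end clamps to v_len
assert len(range(v_len)[slice(2, 100)]) == 6    # overlong stop clamps too
# Note that `s.start or 0` assumes a None or non-negative start.
```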

jaxlib/mosaic/gpu/BUILD

@@ -65,6 +65,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:VectorDialect",
     ],
 )

jaxlib/mosaic/gpu/passes.cc

@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include "jaxlib/mosaic/gpu/passes.h"
 
+#include <cstdint>
 #include <memory>
 #include <utility>
 #include <vector>
@@ -23,6 +24,7 @@ limitations under the License.
 #include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/include/mlir/IR/BuiltinAttributes.h"
 #include "mlir/include/mlir/IR/BuiltinOps.h"
 #include "mlir/include/mlir/IR/SymbolTable.h"
@@ -36,6 +38,49 @@ namespace gpu {
 
 namespace {
 
+// Upstream MLIR does not implement an LLVM lowering pattern for this op.
+struct ConvertExtractStridedSlicePattern final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  mlir::LogicalResult matchAndRewrite(
+      mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
+      mlir::ConversionPatternRewriter &rewriter) const override {
+    auto vty = op.getSourceVectorType();
+    if (vty.getRank() != 1) {
+      return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
+    }
+    int64_t size =
+        (*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    if (size < 0) {
+      return rewriter.notifyMatchFailure(op, "size is negative");
+    }
+    int64_t start =
+        (*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    int64_t stride =
+        (*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    if (stride != 1) {
+      return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
+    }
+    if (start < 0 || start + size > vty.getShape()[0]) {
+      return rewriter.notifyMatchFailure(op, "slice is out of bounds");
+    }
+    mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
+        op.getLoc(), op.getResult().getType());
+    for (int64_t i = 0; i < size; ++i) {
+      result = rewriter.create<mlir::LLVM::InsertElementOp>(
+          op.getLoc(), result,
+          rewriter.create<mlir::LLVM::ExtractElementOp>(
+              op.getLoc(), subst.getVector(),
+              rewriter.create<mlir::LLVM::ConstantOp>(
+                  op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
+          rewriter.create<mlir::LLVM::ConstantOp>(
+              op.getLoc(), rewriter.getI32IntegerAttr(i)));
+    }
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+};
+
 class ConvertGpuToLLVMPass
     : public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
  public:
@@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
     });
     auto symtab = mlir::SymbolTable(getOperation());
     mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
+    patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
     if (mlir::applyPartialConversion(getOperation(), target,
                                      std::move(patterns))
             .failed()) {
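To summarize what the new C++ pattern emits: a unit-stride 1-D `extract_strided_slice` is unrolled into `size` scalar extract/insert pairs over an undef result vector. A small Python model of that unrolling (a sketch of the semantics, not the MLIR C++ API):

```python
# Python model of ConvertExtractStridedSlicePattern's unrolling: start from
# an undef result and copy `size` elements, one extract/insert pair each.
def lower_extract_strided_slice(src: list, start: int, size: int) -> list:
  assert 0 <= start and start + size <= len(src)  # the pattern's bounds check
  result: list = [None] * size  # stands in for llvm.mlir.undef
  for i in range(size):
    result[i] = src[start + i]  # extractelement at start+i, insertelement at i
  return result

assert lower_extract_strided_slice(list("abcdefgh"), 2, 3) == ["c", "d", "e"]
```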