mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS, which is currently cluttered with LLVM trying to align slices to start at bit 0 and failing to CSE the right shifts. PiperOrigin-RevId: 737967890
This commit is contained in:
parent
7a459f0ed1
commit
8da93249d2
@ -1373,15 +1373,30 @@ class FragmentedArray:
|
||||
for group_size in (8, 4, 2):
|
||||
int_ty = ir.IntegerType.get_signless(group_size * 4)
|
||||
while vector_len - offset >= group_size:
|
||||
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
|
||||
reg_slice_int = utils.bitcast(reg_slice, int_ty)
|
||||
if int_ty != i32:
|
||||
reg_slice_int = arith.extsi(i32, reg_slice_int)
|
||||
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
|
||||
out_int_regs.extend(
|
||||
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
|
||||
for part in range(group_size // 2)
|
||||
)
|
||||
# If the vector originates from a slice (common after relayouts), we
|
||||
# can fuse the slicing into the conversion and prevent LLVM from
|
||||
# generating a bunch of shifts to align the vector data to the LSB.
|
||||
# This also lets us share the right shift among more vectors.
|
||||
if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
|
||||
and utils.bitwidth(slice_op.vector.type) == 32
|
||||
and slice_op.strides[0].value == 1):
|
||||
slice_offset = slice_op.offsets[0].value + offset
|
||||
reg_int = utils.bitcast(slice_op.vector, i32)
|
||||
reg_int_shr = arith.shrui(reg_int, c(4, i32))
|
||||
out_int_regs.extend(
|
||||
upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
|
||||
for part in range(group_size // 2)
|
||||
)
|
||||
else:
|
||||
reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
|
||||
reg_slice_int = utils.bitcast(reg_slice, int_ty)
|
||||
if int_ty != i32:
|
||||
reg_slice_int = arith.extsi(i32, reg_slice_int)
|
||||
reg_slice_int_shr = arith.shrui(reg_slice_int, c(4, i32))
|
||||
out_int_regs.extend(
|
||||
upcast_to_bf16(reg_slice_int, reg_slice_int_shr, part=part)
|
||||
for part in range(group_size // 2)
|
||||
)
|
||||
offset += group_size
|
||||
assert offset == vector_len
|
||||
out_vec_int = utils.vector_concat([
|
||||
|
@ -346,7 +346,7 @@ def bitwidth_impl(ty: ir.Type):
|
||||
return ir.IntegerType(ty).width
|
||||
if ir.FloatType.isinstance(ty):
|
||||
return ir.FloatType(ty).width
|
||||
if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
|
||||
if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
|
||||
return MBARRIER_BYTES * 8
|
||||
if ir.VectorType.isinstance(ty):
|
||||
vty = ir.VectorType(ty)
|
||||
@ -1237,17 +1237,15 @@ def ceil_div(x: int, y: int):
|
||||
|
||||
|
||||
def vector_slice(v: ir.Value, s: slice):
|
||||
i32 = ir.IntegerType.get_signless(32)
|
||||
v_ty = ir.VectorType(v.type)
|
||||
if len(v_ty.shape) != 1:
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError(v_ty)
|
||||
[v_len] = v_ty.shape
|
||||
it = range(v_len)[s]
|
||||
result = llvm.mlir_undef(ir.VectorType.get((len(it),), v_ty.element_type))
|
||||
for tgt, src in enumerate(it):
|
||||
elem = llvm.extractelement(v, c(src, i32))
|
||||
result = llvm.insertelement(result, elem, c(tgt, i32))
|
||||
return result
|
||||
slice_length = len(range(v_len)[s])
|
||||
return vector.extract_strided_slice(
|
||||
ir.VectorType.get((slice_length,), v_ty.element_type),
|
||||
v, [s.start or 0], [slice_length], [1],
|
||||
)
|
||||
|
||||
|
||||
def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
|
||||
|
@ -65,6 +65,7 @@ cc_library(
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@llvm-project//mlir:VectorDialect",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "jaxlib/mosaic/gpu/passes.h"
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -23,6 +24,7 @@ limitations under the License.
|
||||
#include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
#include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/include/mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/include/mlir/IR/SymbolTable.h"
|
||||
@ -36,6 +38,49 @@ namespace gpu {
|
||||
|
||||
namespace {
|
||||
|
||||
// Upstream MLIR does not implement an LLVM lowering pattern for this op.
|
||||
struct ConvertExtractStridedSlicePattern final
|
||||
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
mlir::LogicalResult matchAndRewrite(
|
||||
mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
auto vty = op.getSourceVectorType();
|
||||
if (vty.getRank() != 1) {
|
||||
return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
|
||||
}
|
||||
int64_t size =
|
||||
(*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
|
||||
if (size < 0) {
|
||||
return rewriter.notifyMatchFailure(op, "size is negative");
|
||||
}
|
||||
int64_t start =
|
||||
(*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
|
||||
int64_t stride =
|
||||
(*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
|
||||
if (stride != 1) {
|
||||
return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
|
||||
}
|
||||
if (start < 0 || start + size > vty.getShape()[0]) {
|
||||
return rewriter.notifyMatchFailure(op, "slice is out of bounds");
|
||||
}
|
||||
mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
|
||||
op.getLoc(), op.getResult().getType());
|
||||
for (int64_t i = 0; i < size; ++i) {
|
||||
result = rewriter.create<mlir::LLVM::InsertElementOp>(
|
||||
op.getLoc(), result,
|
||||
rewriter.create<mlir::LLVM::ExtractElementOp>(
|
||||
op.getLoc(), subst.getVector(),
|
||||
rewriter.create<mlir::LLVM::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
|
||||
rewriter.create<mlir::LLVM::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(i)));
|
||||
}
|
||||
rewriter.replaceOp(op, result);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertGpuToLLVMPass
|
||||
: public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
|
||||
public:
|
||||
@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
|
||||
});
|
||||
auto symtab = mlir::SymbolTable(getOperation());
|
||||
mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
|
||||
patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
|
||||
if (mlir::applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))
|
||||
.failed()) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user