[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
This allows us to significantly simplify the generated PTX/SASS, which is currently cluttered with LLVM trying to align slices to start at bit 0 and failing to CSE the right shifts.

PiperOrigin-RevId: 737967890
This commit is contained in: parent 7a459f0ed1, commit 8da93249d2
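A minimal pure-Python sketch of the bit-level idea behind the change (illustrative only, not Mosaic GPU code): eight signed 4-bit values are packed into one 32-bit register, so upcasting a slice of that register does not require shifting the slice down to bit 0 first. The packed word and a single word-wide right shift can be reused for every slice; only the part index into the word changes.

    def unpack_s4(word: int, idx: int) -> int:
        # Extract signed 4-bit element `idx` from a 32-bit word packed
        # little-endian (element 0 lives in bits 0..3).
        nibble = (word >> (4 * idx)) & 0xF
        return nibble - 16 if nibble >= 8 else nibble

    def upcast_slice(word: int, offset: int, length: int) -> list[int]:
        # Instead of materializing the slice at bit 0 (the shifts LLVM would
        # otherwise generate), index the original word at `offset + i`.
        return [unpack_s4(word, offset + i) for i in range(length)]

    packed = 0x8421F731  # eight s4 values in one i32
    assert upcast_slice(packed, 2, 4) == [unpack_s4(packed, i) for i in range(2, 6)]

The diff below applies the same idea at the IR level: when a register is itself produced by a `vector.extract_strided_slice`, the slice offset is folded into the `part` argument of the upcast, and the `>> 4` is computed once on the full 32-bit source.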
@@ -1373,6 +1373,21 @@ class FragmentedArray:
     for group_size in (8, 4, 2):
       int_ty = ir.IntegerType.get_signless(group_size * 4)
       while vector_len - offset >= group_size:
+        # If the vector originates from a slice (common after relayouts), we
+        # can fuse the slicing into the conversion and prevent LLVM from
+        # generating a bunch of shifts to align the vector data to the LSB.
+        # This also lets us share the right shift among more vectors.
+        if (isinstance(slice_op := reg.owner.opview, vector.ExtractStridedSliceOp)
+            and utils.bitwidth(slice_op.vector.type) == 32
+            and slice_op.strides[0].value == 1):
+          slice_offset = slice_op.offsets[0].value + offset
+          reg_int = utils.bitcast(slice_op.vector, i32)
+          reg_int_shr = arith.shrui(reg_int, c(4, i32))
+          out_int_regs.extend(
+              upcast_to_bf16(reg_int, reg_int_shr, part=(slice_offset // 2 + part))
+              for part in range(group_size // 2)
+          )
+        else:
           reg_slice = utils.vector_slice(reg, slice(offset, offset + group_size))
           reg_slice_int = utils.bitcast(reg_slice, int_ty)
           if int_ty != i32:
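The enabler here is that the defining op of `reg` can be inspected directly through the MLIR Python bindings. Below is a hedged, self-contained restatement of the producer check (upstream `mlir` import path assumed; it also assumes `reg` is an op result whose source vector has integer elements, and `vector_bitwidth` is an illustrative stand-in for `utils.bitwidth`):

    from mlir import ir
    from mlir.dialects import vector

    def vector_bitwidth(ty: ir.Type) -> int:
        # Stand-in for utils.bitwidth, assuming an integer element type.
        vty = ir.VectorType(ty)
        return vty.shape[0] * ir.IntegerType(vty.element_type).width

    def is_fuseable_slice(reg: ir.Value) -> bool:
        # Walk from the SSA value to its defining op (as its registered OpView
        # subclass) and accept only 32-bit sources sliced with stride 1.
        op = reg.owner.opview
        return (isinstance(op, vector.ExtractStridedSliceOp)
                and vector_bitwidth(op.vector.type) == 32
                and op.strides[0].value == 1)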
@@ -346,7 +346,7 @@ def bitwidth_impl(ty: ir.Type):
     return ir.IntegerType(ty).width
   if ir.FloatType.isinstance(ty):
     return ir.FloatType(ty).width
-  if dialect is not None and ir.Type.parse("!mosaic_gpu.barrier"):
+  if dialect is not None and ty == ir.Type.parse("!mosaic_gpu.barrier"):
     return MBARRIER_BYTES * 8
   if ir.VectorType.isinstance(ty):
     vty = ir.VectorType(ty)
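The one-line change above fixes a truthiness bug: `ir.Type.parse(...)` returns a type object, which is always truthy, so the old branch fired for any type that reached it (vector types are only checked afterwards) and reported the mbarrier bitwidth for it. The new condition compares the queried type against the parsed barrier type. A hedged illustration with builtin types as stand-ins (upstream `mlir` bindings assumed):

    from mlir import ir

    with ir.Context():
        f32 = ir.F32Type.get()
        i32 = ir.IntegerType.get_signless(32)
        # Any successfully constructed or parsed type is truthy ...
        assert i32
        # ... so `if dialect is not None and <parsed type>:` matched every `ty`.
        # The corrected check compares types for equality instead:
        assert f32 != i32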
@@ -1237,17 +1237,15 @@ def ceil_div(x: int, y: int):
 
 
 def vector_slice(v: ir.Value, s: slice):
-  i32 = ir.IntegerType.get_signless(32)
   v_ty = ir.VectorType(v.type)
   if len(v_ty.shape) != 1:
-    raise NotImplementedError
+    raise NotImplementedError(v_ty)
   [v_len] = v_ty.shape
-  it = range(v_len)[s]
-  result = llvm.mlir_undef(ir.VectorType.get((len(it),), v_ty.element_type))
-  for tgt, src in enumerate(it):
-    elem = llvm.extractelement(v, c(src, i32))
-    result = llvm.insertelement(result, elem, c(tgt, i32))
-  return result
+  slice_length = len(range(v_len)[s])
+  return vector.extract_strided_slice(
+      ir.VectorType.get((slice_length,), v_ty.element_type),
+      v, [s.start or 0], [slice_length], [1],
+  )
 
 
 def vector_concat(vectors: Sequence[ir.Value]) -> ir.Value:
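After the rewrite, `vector_slice` emits a single `vector.extract_strided_slice` instead of an `llvm.mlir.undef` followed by a chain of extract/insert element pairs, which is exactly the op the new fusion in `FragmentedArray` pattern-matches on. A hedged, self-contained sketch of what it now produces for a `[2:6]` slice of a `vector<8xi4>` (upstream `mlir` Python bindings assumed; the builder call mirrors the one in the hunk above):

    from mlir import ir
    from mlir.dialects import func, vector

    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            i4 = ir.IntegerType.get_signless(4)
            vec8 = ir.VectorType.get([8], i4)
            vec4 = ir.VectorType.get([4], i4)

            @func.FuncOp.from_py_func(vec8)
            def slice_2_6(v):
                # Equivalent of vector_slice(v, slice(2, 6)):
                # offsets=[2], sizes=[4], strides=[1].
                return vector.extract_strided_slice(vec4, v, [2], [4], [1])

        print(module)  # a single vector.extract_strided_slice in the func body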
@@ -65,6 +65,7 @@ cc_library(
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
+        "@llvm-project//mlir:VectorDialect",
     ],
 )
 
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "jaxlib/mosaic/gpu/passes.h"
+#include <cstdint>
 #include <memory>
 #include <utility>
 #include <vector>
@@ -23,6 +24,7 @@ limitations under the License.
 #include "mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/include/mlir/IR/BuiltinAttributes.h"
 #include "mlir/include/mlir/IR/BuiltinOps.h"
 #include "mlir/include/mlir/IR/SymbolTable.h"
@@ -36,6 +38,49 @@ namespace gpu {
 
 namespace {
 
+// Upstream MLIR does not implement an LLVM lowering pattern for this op.
+struct ConvertExtractStridedSlicePattern final
+    : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  mlir::LogicalResult matchAndRewrite(
+      mlir::vector::ExtractStridedSliceOp op, OpAdaptor subst,
+      mlir::ConversionPatternRewriter &rewriter) const override {
+    auto vty = op.getSourceVectorType();
+    if (vty.getRank() != 1) {
+      return rewriter.notifyMatchFailure(op, "only 1-D vectors are supported");
+    }
+    int64_t size =
+        (*op.getSizes().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    if (size < 0) {
+      return rewriter.notifyMatchFailure(op, "size is negative");
+    }
+    int64_t start =
+        (*op.getOffsets().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    int64_t stride =
+        (*op.getStrides().getAsRange<mlir::IntegerAttr>().begin()).getSInt();
+    if (stride != 1) {
+      return rewriter.notifyMatchFailure(op, "only stride 1 is supported");
+    }
+    if (start < 0 || start + size > vty.getShape()[0]) {
+      return rewriter.notifyMatchFailure(op, "slice is out of bounds");
+    }
+    mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
+        op.getLoc(), op.getResult().getType());
+    for (int64_t i = 0; i < size; ++i) {
+      result = rewriter.create<mlir::LLVM::InsertElementOp>(
+          op.getLoc(), result,
+          rewriter.create<mlir::LLVM::ExtractElementOp>(
+              op.getLoc(), subst.getVector(),
+              rewriter.create<mlir::LLVM::ConstantOp>(
+                  op.getLoc(), rewriter.getI32IntegerAttr(i + start))),
+          rewriter.create<mlir::LLVM::ConstantOp>(
+              op.getLoc(), rewriter.getI32IntegerAttr(i)));
+    }
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+};
+
 class ConvertGpuToLLVMPass
     : public jaxlib::mlir::Pass<ConvertGpuToLLVMPass, mlir::ModuleOp> {
  public:
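The pattern expands a rank-1, stride-1 `vector.extract_strided_slice` into `size` extract/insert element pairs over an `llvm.mlir.undef` seed. A minimal pure-Python model of that expansion (illustrative only):

    def lower_extract_strided_slice(source: list, start: int, size: int) -> list:
        # Seed, corresponding to llvm.mlir.undef of the result vector type.
        result = [None] * size
        for i in range(size):
            # llvm.extractelement at index start + i ...
            elem = source[start + i]
            # ... followed by llvm.insertelement at index i.
            result[i] = elem
        return result

    assert lower_extract_strided_slice(list(range(8)), 2, 4) == [2, 3, 4, 5]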
@@ -58,6 +103,7 @@ class ConvertGpuToLLVMPass
     });
     auto symtab = mlir::SymbolTable(getOperation());
     mlir::populateGpuToLLVMConversionPatterns(converter, patterns, false);
+    patterns.insert<ConvertExtractStridedSlicePattern>(&getContext());
     if (mlir::applyPartialConversion(getOperation(), target,
                                      std::move(patterns))
             .failed()) {