Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Use default tiling in scratch buffers if XLA enables it

PiperOrigin-RevId: 650493683

parent: d0394bfee5
commit: 0da9b69285
@@ -286,6 +286,9 @@ def _lower_tpu_kernel(
   pipeline.run(module.operation)
   dump_mlir(module, "post-hlo-conversion")

+  # Note: we don't pass the TpuTilingFlags here, since we don't know the
+  # tiling decisions made by the compiler / what flags are enabled at this
+  # point, so we assume everything can be tiled up to default tiling.
   pipeline = [
       f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
   ]
@@ -679,11 +679,12 @@ def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncO
     "::mlir::func::FuncDialect",
     "::mlir::memref::MemRefDialect",
   ];
-  let constructor = "::mlir::tpu::createInferMemRefLayoutPass(-1)";
+  let constructor = "::mlir::tpu::createInferMemRefLayoutPass()";
   let options = [
     // If hardware_generation is not set, the default value of -1 will crash on
     // runOnOperation.
     Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
+    Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">,
   ];
 }
@@ -48,10 +48,16 @@ class TPUDialect;
 namespace mlir {
 namespace tpu {

+struct TpuTilingFlags {
+  bool use_x16_large_second_minor = false;
+  bool use_x8_large_second_minor = false;
+  bool use_x4_large_second_minor = false;
+};
+
 std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);

 std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
-    int hardware_generation = -1);
+    int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {});

 std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass();
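For illustration only (not part of this change): a minimal sketch of how the new createInferMemRefLayoutPass overload and the TpuTilingFlags struct above might be used when assembling a pass pipeline in C++. The pass-manager setup, the helper name, and the chosen generation/flag values are assumptions.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

// Hypothetical helper: nests the pass under func.func, matching the Python
// pipeline string earlier in this diff.
void addInferMemRefLayout(mlir::OpPassManager &pm) {
  mlir::tpu::TpuTilingFlags flags;
  flags.use_x16_large_second_minor = true;  // assumed: XLA enabled 16x128 tiles
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tpu::createInferMemRefLayoutPass(/*hardware_generation=*/5, flags));
}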
@@ -188,9 +188,12 @@ FailureOr<Value> getInternalScratch(RewriteContext &ctx, OpBuilder &builder,
   if (sublane_count > ctx.max_sublanes_in_scratch) {
     return failure();
   }
+  // We can omit tpu_tiling_flags here because, for internal scratch, the
+  // tiling does not matter (its shape is (N, 128)).
   FAILUREOR_ASSIGN_OR_RETURN(
       MemRefType scratch_ref_ty,
-      inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation));
+      inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation,
+                  /*tpu_tiling_flags=*/{}));
   return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
       .getResult();
 }
@@ -503,11 +506,14 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
     return ctx.func.emitOpError(
         "Not implemented: function has scratch_operands");
   }
+  // We can omit tpu_tiling_flags here since we invoke inferMemref only for
+  // constant operands which are kernel parameters that will have their layouts
+  // overridden before the pass pipeline runs anyway.
   FAILUREOR_ASSIGN_OR_RETURN(
       MemRefType arg_type,
       inferMemref(
           MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
-          ctx.hardware_generation));
+          ctx.hardware_generation, /*tpu_tiling_flags=*/{}));
   const BlockArgument argument =
       entry_block.insertArgument(entry_block.getNumArguments() - 1, arg_type,
                                  UnknownLoc::get(ctx.getMLIRContext()));
@@ -56,17 +56,40 @@ SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
 // Arguments:
 //   num_128s: A number of 128-element groups in the full operand.
 //   hardware_generation: An integer indicating the target TPU generation.
+//   tpu_tiling_flags: A struct of flags indicating which large tiling modes are
+//     enabled by XLA for memrefs.
 //   bitwidth: The bitwidth of the element type of the operand.
 int getTilingFactor(const int num_128s, const int hardware_generation,
+                    const TpuTilingFlags &tpu_tiling_flags,
                     const int8_t bitwidth) {
   CHECK(llvm::isPowerOf2_32(bitwidth));
   CHECK_LE(4, bitwidth);
   CHECK_LE(bitwidth, 32);
   const int packing = 32 / bitwidth;
   const int min_tiling = (1 + (hardware_generation < 4)) * packing;
-  const int max_tiling = 8;
+  const int max_normal_tiling = 8;
+
+  const int large_tiling = [&] {
+    if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) {
+      return 64;
+    }
+    if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
+      return 32;
+    }
+    if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
+      return 16;
+    }
+    return 8;
+  }();
+
+  // Use large tiling if our operand is tall enough to fit at least one full
+  // tile.
+  if (large_tiling <= num_128s) {
+    return large_tiling;
+  }
+
   int tiling = min_tiling;
-  while (tiling < std::min(num_128s, max_tiling)) {
+  while (tiling < std::min(num_128s, max_normal_tiling)) {
     tiling *= 2;
   }
   return tiling;
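For clarity, the block below restates the tiling-factor selection above as a condensed, self-contained sketch (CHECKs and LLVM helpers omitted; the local TpuTilingFlags mirrors the struct added to tpu_dialect.h), followed by a few worked examples. It is a reading aid for this diff, not the actual implementation.

#include <algorithm>

struct TpuTilingFlags {  // mirrors the struct added to tpu_dialect.h
  bool use_x16_large_second_minor = false;
  bool use_x8_large_second_minor = false;
  bool use_x4_large_second_minor = false;
};

int getTilingFactorSketch(int num_128s, int hardware_generation,
                          const TpuTilingFlags &flags, int bitwidth) {
  const int packing = 32 / bitwidth;
  const int min_tiling = (1 + (hardware_generation < 4)) * packing;
  const int max_normal_tiling = 8;
  int large_tiling = 8;
  if (bitwidth == 4 && flags.use_x4_large_second_minor) large_tiling = 64;
  if (bitwidth == 8 && flags.use_x8_large_second_minor) large_tiling = 32;
  if (bitwidth == 16 && flags.use_x16_large_second_minor) large_tiling = 16;
  // Large tiling only applies when the operand spans at least one full tile.
  if (large_tiling <= num_128s) return large_tiling;
  // Otherwise grow the tile from the minimum by doubling, capped at 8 rows.
  int tiling = min_tiling;
  while (tiling < std::min(num_128s, max_normal_tiling)) tiling *= 2;
  return tiling;
}

// Worked examples (hardware_generation >= 4):
//   bitwidth 16, use_x16_large_second_minor, num_128s = 64 -> 16 (large tile)
//   bitwidth 16, use_x16_large_second_minor, num_128s = 5  -> 8  (2 -> 4 -> 8)
//   bitwidth  8, no flags set,               num_128s = 3  -> 4  (min tiling)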
@@ -74,6 +97,7 @@ int getTilingFactor(const int num_128s, const int hardware_generation,

 FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                                        const int hardware_generation,
+                                       const TpuTilingFlags &tpu_tiling_flags,
                                        int64_t leading_tile_rows = 0) {
   if (auto tiled_layout_attr =
           dyn_cast<TiledLayoutAttr>(memref_ty.getLayout())) {
@@ -97,7 +121,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
   if (memref_ty.getRank() == 1) {
     const int64_t leading_tile =
         getTilingFactor(llvm::divideCeil(memref_ty.getShape().back(), 128),
-                        hardware_generation, bitwidth) *
+                        hardware_generation, tpu_tiling_flags, bitwidth) *
         128;
     SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
     if (bitwidth != 32) {
@@ -115,8 +139,8 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
   const ArrayRef<int64_t> shape = memref_ty.getShape();
   const int64_t second_minor = shape[shape.size() - 2];
   if (leading_tile_rows == 0) {
-    leading_tile_rows =
-        getTilingFactor(second_minor, hardware_generation, bitwidth);
+    leading_tile_rows = getTilingFactor(second_minor, hardware_generation,
+                                        tpu_tiling_flags, bitwidth);
   }
   SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, 128})};
   if (bitwidth != 32) {
@@ -160,6 +184,7 @@ LogicalResult checkTiles(MLIRContext *mlir_ctx,

 FailureOr<MemRefType> inferMemref(MemRefType memref,
                                   const int hardware_generation,
+                                  const TpuTilingFlags &tpu_tiling_flags,
                                   int64_t leading_tile_rows) {
   if (isa<SemaphoreType, DMASemaphoreType>(memref.getElementType())) {
     const Attribute semaphore_mem = tpu::MemorySpaceAttr::get(
@@ -180,9 +205,9 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
       tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem);
   const Attribute memory_space =
       memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace();
-  FAILUREOR_ASSIGN_OR_RETURN(
-      const TiledLayoutAttr layout,
-      inferLayout(memref, hardware_generation, leading_tile_rows));
+  FAILUREOR_ASSIGN_OR_RETURN(const TiledLayoutAttr layout,
+                             inferLayout(memref, hardware_generation,
+                                         tpu_tiling_flags, leading_tile_rows));

   const ArrayRef<xla::Tile> tiles = layout.getTiles();
   if (failed(checkTiles(memref.getContext(), tiles))) {
@@ -203,12 +228,14 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
                          memory_space);
 }

-LogicalResult inferOp(Operation &op, const int hardware_generation) {
+LogicalResult inferOp(Operation &op, const int hardware_generation,
+                      const TpuTilingFlags &tpu_tiling_flags) {
   if (auto alloca_op = dyn_cast<memref::AllocaOp>(op)) {
     TypedValue<MemRefType> arg = alloca_op.getResult();
     const MemRefType memref_ty = alloca_op.getResult().getType();
-    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
-                               inferMemref(memref_ty, hardware_generation));
+    FAILUREOR_ASSIGN_OR_RETURN(
+        const MemRefType new_memref_ty,
+        inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -223,8 +250,9 @@ LogicalResult inferOp(Operation &op, const int hardware_generation) {
   } else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
     TypedValue<MemRefType> arg = alloca_op.getResult();
     const MemRefType memref_ty = alloca_op.getResult().getType();
-    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
-                               inferMemref(memref_ty, hardware_generation));
+    FAILUREOR_ASSIGN_OR_RETURN(
+        const MemRefType new_memref_ty,
+        inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -240,7 +268,7 @@ LogicalResult inferOp(Operation &op, const int hardware_generation) {
   for (Region &region : op.getRegions()) {
     for (Block &block : region) {
       for (Operation& op : block) {
-        if (failed(inferOp(op, hardware_generation))) {
+        if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
           return failure();
         }
       }
@@ -249,7 +277,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation) {
   return success();
 }

-LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
+LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
+                        const TpuTilingFlags &tpu_tiling_flags) {
   if (!f.getBody().hasOneBlock()) {
     return f.emitOpError("Functions should only have a single block");
   }
@@ -273,7 +302,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {

     FAILUREOR_ASSIGN_OR_RETURN(
         const MemRefType new_memref_ty,
-        inferMemref(memref_ty, hardware_generation, leading_tile_rows));
+        inferMemref(memref_ty, hardware_generation, tpu_tiling_flags,
+                    leading_tile_rows));
     arg.setType(new_memref_ty);
     new_arg_types.push_back(arg.getType());
     if (memref_ty != new_memref_ty) {
@@ -292,7 +322,7 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
   f.setFunctionType(
       builder.getAttr<FunctionType>(new_arg_types, f.getResultTypes()));
   for (Operation &op : entry.getOperations()) {
-    if (failed(inferOp(op, hardware_generation))) {
+    if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
       return failure();
     }
   }
@@ -307,14 +337,15 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
 // Arguments:
 //   module: The MLIR module on which to perform the inference.
 //   hardware_generation: The TPU hardware generation to target.
-LogicalResult inferModule(ModuleOp module, const int hardware_generation) {
+LogicalResult inferModule(ModuleOp module, const int hardware_generation,
+                          const TpuTilingFlags &tpu_tiling_flags) {
   // TODO(apaszke): Do layout assignment for scoped allocations too.
   for (Operation &op : *module.getBody()) {
     auto f = dyn_cast<func::FuncOp>(op);
     if (f == nullptr) {
       return module.emitOpError("Expected only FuncOps but found ") << op;
     }
-    if (failed(inferFunc(f, hardware_generation))) {
+    if (failed(inferFunc(f, hardware_generation, tpu_tiling_flags))) {
       return failure();
     }
   }
@@ -323,8 +354,10 @@ LogicalResult inferModule(ModuleOp module, const int hardware_generation) {

 struct InferMemRefLayoutPass
     : public impl::InferMemRefLayoutPassBase<InferMemRefLayoutPass> {
-  InferMemRefLayoutPass(int hardware_generation_) {
+  InferMemRefLayoutPass(int hardware_generation_,
+                        const TpuTilingFlags &tpu_tiling_flags_) {
     hardware_generation = hardware_generation_;
+    tpu_tiling_flags = tpu_tiling_flags_;
   }
   void runOnOperation() override {
     // Fail if hardware_generation has not been set from the default value.
@@ -333,7 +366,7 @@ struct InferMemRefLayoutPass
       return;
     }
     func::FuncOp func = getOperation();
-    if (failed(inferFunc(func, hardware_generation))) {
+    if (failed(inferFunc(func, hardware_generation, tpu_tiling_flags))) {
       signalPassFailure();
       return;
     }
@@ -341,8 +374,9 @@ struct InferMemRefLayoutPass
 };

 std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
-    int hardware_generation) {
-  return std::make_unique<InferMemRefLayoutPass>(hardware_generation);
+    int hardware_generation, const TpuTilingFlags &tpu_tiling_flags_) {
+  return std::make_unique<InferMemRefLayoutPass>(hardware_generation,
+                                                 tpu_tiling_flags_);
 }

 } // namespace mlir::tpu
@@ -5,10 +5,12 @@

 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LogicalResult.h"
+#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

 namespace mlir::tpu {

 FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
+                                  const TpuTilingFlags& tpu_tiling_flags,
                                   int64_t leading_tile_rows = 0);

 const std::string_view kLeadingTileRows = "leading_tile_rows";
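A minimal usage sketch of the updated inferMemref declaration above (not part of this change). The context setup, dialect loading, concrete shape, and hardware generation are assumptions made for illustration.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
// ...plus the header above that declares inferMemref.

// Hypothetical example: infer a tiled VMEM layout for a bf16 memref with the
// large 16x128 second-minor tiling enabled.
mlir::FailureOr<mlir::MemRefType> exampleInferMemref(mlir::MLIRContext &ctx) {
  ctx.loadDialect<mlir::tpu::TPUDialect>();  // assumed needed for TPU attributes
  mlir::Builder b(&ctx);
  auto memref_ty = mlir::MemRefType::get({512, 256}, b.getBF16Type());
  mlir::tpu::TpuTilingFlags flags;
  flags.use_x16_large_second_minor = true;
  // leading_tile_rows is left at its default of 0 (inferred from the shape).
  return mlir::tpu::inferMemref(memref_ty, /*hardware_generation=*/5, flags);
}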
@@ -1805,6 +1805,20 @@ class VectorLayoutInferer {
         }
         // Fall through.
       }
+      if (auto store = dyn_cast<vector::StoreOp>(operand.getOwner())) {
+        auto maybe_tiling = verifyMemoryTiling(
+            store, getMemRefLayout(store.getBase()).getTiles(),
+            store.getMemRefType().getRank(),
+            store.getMemRefType().getElementTypeBitWidth());
+        if (maybe_tiling) {
+          auto tiling = *maybe_tiling;
+          if (tiling ==
+              nativeTiling(store.getMemRefType().getElementTypeBitWidth())) {
+            continue;
+          }
+        }
+        // Fall through.
+      }
       return false;
     }
     return true;