Use default tiling in scratch buffers if XLA enables it

PiperOrigin-RevId: 650493683
jax authors 2024-07-08 22:48:26 -07:00 committed by jax authors
parent d0394bfee5
commit 0da9b69285
7 changed files with 93 additions and 27 deletions


@@ -286,6 +286,9 @@ def _lower_tpu_kernel(
pipeline.run(module.operation)
dump_mlir(module, "post-hlo-conversion")
# Note: we don't pass the TpuTilingFlags here, since at this point we don't
# know which tiling flags the compiler has enabled or what tiling decisions
# it will make, so we assume everything can be tiled up to the default tiling.
pipeline = [
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
]


@@ -679,11 +679,12 @@ def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncO
"::mlir::func::FuncDialect",
"::mlir::memref::MemRefDialect",
];
let constructor = "::mlir::tpu::createInferMemRefLayoutPass(-1)";
let constructor = "::mlir::tpu::createInferMemRefLayoutPass()";
let options = [
// If hardware_generation is not set, the default value of -1 will make the
// pass fail in runOnOperation.
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">,
];
}


@@ -48,10 +48,16 @@ class TPUDialect;
namespace mlir {
namespace tpu {
struct TpuTilingFlags {
bool use_x16_large_second_minor = false;
bool use_x8_large_second_minor = false;
bool use_x4_large_second_minor = false;
};
std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
int hardware_generation = -1);
int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {});
std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass();
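As a usage illustration (not part of this commit), here is a minimal sketch of how a caller might construct the flags and instantiate the pass declared above. The addLayoutInference helper, the PassManager wiring, and the particular flag chosen are assumptions for the example; only TpuTilingFlags and createInferMemRefLayoutPass come from the header hunk above.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

// Hypothetical helper: schedules memref layout inference on every function,
// with one of XLA's large-tiling modes enabled.
void addLayoutInference(mlir::PassManager &pm, int hardware_generation) {
  mlir::tpu::TpuTilingFlags flags;
  flags.use_x16_large_second_minor = true;  // request (16, 128) tiles for 16-bit types
  // Both parameters have defaults (-1 and {}), but the pass fails in
  // runOnOperation when hardware_generation is left at -1, so pass a real
  // generation here.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tpu::createInferMemRefLayoutPass(hardware_generation, flags));
}

Whether the large tiling actually takes effect still depends on the operand shape: getTilingFactor below only picks the large tile when the operand spans at least that many rows of 128 lanes.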


@@ -188,9 +188,12 @@ FailureOr<Value> getInternalScratch(RewriteContext &ctx, OpBuilder &builder,
if (sublane_count > ctx.max_sublanes_in_scratch) {
return failure();
}
// We can omit tpu_tiling_flags here because, for internal scratch, the
// tiling does not matter (its shape is (N, 128)).
FAILUREOR_ASSIGN_OR_RETURN(
MemRefType scratch_ref_ty,
inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation));
inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation,
/*tpu_tiling_flags=*/{}));
return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
.getResult();
}
@@ -503,11 +506,14 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
return ctx.func.emitOpError(
"Not implemented: function has scratch_operands");
}
// We can omit tpu_tiling_flags here since we invoke inferMemref only for
// constant operands, which are kernel parameters whose layouts will be
// overridden before the pass pipeline runs anyway.
FAILUREOR_ASSIGN_OR_RETURN(
MemRefType arg_type,
inferMemref(
MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
ctx.hardware_generation));
ctx.hardware_generation, /*tpu_tiling_flags=*/{}));
const BlockArgument argument =
entry_block.insertArgument(entry_block.getNumArguments() - 1, arg_type,
UnknownLoc::get(ctx.getMLIRContext()));


@@ -56,17 +56,40 @@ SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
// Arguments:
// num_128s: The number of 128-element groups in the full operand.
// hardware_generation: An integer indicating the target TPU generation.
// tpu_tiling_flags: A struct of flags indicating which large tiling modes are
// enabled by XLA for memrefs.
// bitwidth: The bitwidth of the element type of the operand.
int getTilingFactor(const int num_128s, const int hardware_generation,
const TpuTilingFlags &tpu_tiling_flags,
const int8_t bitwidth) {
CHECK(llvm::isPowerOf2_32(bitwidth));
CHECK_LE(4, bitwidth);
CHECK_LE(bitwidth, 32);
const int packing = 32 / bitwidth;
const int min_tiling = (1 + (hardware_generation < 4)) * packing;
const int max_tiling = 8;
const int max_normal_tiling = 8;
const int large_tiling = [&] {
if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) {
return 64;
}
if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
return 32;
}
if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
return 16;
}
return 8;
}();
// Use large tiling if our operand is tall enough to fit at least one full
// tile.
if (large_tiling <= num_128s) {
return large_tiling;
}
int tiling = min_tiling;
while (tiling < std::min(num_128s, max_tiling)) {
while (tiling < std::min(num_128s, max_normal_tiling)) {
tiling *= 2;
}
return tiling;
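To make the flag's effect concrete, the following standalone sketch (not part of the commit; assumed inputs, simplified to the 16-bit flag only) mirrors the selection logic above and prints the chosen second-minor tile size for a few cases.

#include <algorithm>
#include <cstdio>

namespace {
struct Flags { bool use_x16_large_second_minor = false; };

// Mirrors getTilingFactor above, specialized to 16-bit elements (illustration).
int tilingFactor(int num_128s, int hardware_generation, Flags flags) {
  const int bitwidth = 16;
  const int packing = 32 / bitwidth;                                 // 2
  const int min_tiling = (1 + (hardware_generation < 4)) * packing;  // 2 on gen >= 4
  const int large_tiling = flags.use_x16_large_second_minor ? 16 : 8;
  if (large_tiling <= num_128s) return large_tiling;  // operand is tall enough
  int tiling = min_tiling;
  while (tiling < std::min(num_128s, 8)) tiling *= 2;
  return tiling;
}
}  // namespace

int main() {
  // An operand whose second-minor dimension is 64, on a gen-5 TPU:
  std::printf("%d\n", tilingFactor(64, 5, Flags{false}));  // 8  (default tiling)
  std::printf("%d\n", tilingFactor(64, 5, Flags{true}));   // 16 (large tiling)
  // A short operand (num_128s = 3) never reaches the large tile and falls
  // back to doubling from min_tiling: 2 -> 4.
  std::printf("%d\n", tilingFactor(3, 5, Flags{true}));    // 4
}

With the flag on and a tall enough operand, the second-minor tile grows from the default 8 to 16 rows (analogously 32 for 8-bit and 64 for 4-bit types per the lambda above), while short operands keep the power-of-two fallback capped at 8.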
@@ -74,6 +97,7 @@ int getTilingFactor(const int num_128s, const int hardware_generation,
FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
const int hardware_generation,
const TpuTilingFlags &tpu_tiling_flags,
int64_t leading_tile_rows = 0) {
if (auto tiled_layout_attr =
dyn_cast<TiledLayoutAttr>(memref_ty.getLayout())) {
@@ -97,7 +121,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
if (memref_ty.getRank() == 1) {
const int64_t leading_tile =
getTilingFactor(llvm::divideCeil(memref_ty.getShape().back(), 128),
hardware_generation, bitwidth) *
hardware_generation, tpu_tiling_flags, bitwidth) *
128;
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
if (bitwidth != 32) {
@@ -115,8 +139,8 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
const ArrayRef<int64_t> shape = memref_ty.getShape();
const int64_t second_minor = shape[shape.size() - 2];
if (leading_tile_rows == 0) {
leading_tile_rows =
getTilingFactor(second_minor, hardware_generation, bitwidth);
leading_tile_rows = getTilingFactor(second_minor, hardware_generation,
tpu_tiling_flags, bitwidth);
}
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, 128})};
if (bitwidth != 32) {
@@ -160,6 +184,7 @@ LogicalResult checkTiles(MLIRContext *mlir_ctx,
FailureOr<MemRefType> inferMemref(MemRefType memref,
const int hardware_generation,
const TpuTilingFlags &tpu_tiling_flags,
int64_t leading_tile_rows) {
if (isa<SemaphoreType, DMASemaphoreType>(memref.getElementType())) {
const Attribute semaphore_mem = tpu::MemorySpaceAttr::get(
@@ -180,9 +205,9 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem);
const Attribute memory_space =
memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace();
FAILUREOR_ASSIGN_OR_RETURN(
const TiledLayoutAttr layout,
inferLayout(memref, hardware_generation, leading_tile_rows));
FAILUREOR_ASSIGN_OR_RETURN(const TiledLayoutAttr layout,
inferLayout(memref, hardware_generation,
tpu_tiling_flags, leading_tile_rows));
const ArrayRef<xla::Tile> tiles = layout.getTiles();
if (failed(checkTiles(memref.getContext(), tiles))) {
@@ -203,12 +228,14 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
memory_space);
}
LogicalResult inferOp(Operation &op, const int hardware_generation) {
LogicalResult inferOp(Operation &op, const int hardware_generation,
const TpuTilingFlags &tpu_tiling_flags) {
if (auto alloca_op = dyn_cast<memref::AllocaOp>(op)) {
TypedValue<MemRefType> arg = alloca_op.getResult();
const MemRefType memref_ty = alloca_op.getResult().getType();
FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation));
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
alloca_op.getResult().setType(new_memref_ty);
if (memref_ty != new_memref_ty) {
OpBuilder builder(alloca_op->getContext());
@@ -223,8 +250,9 @@ LogicalResult inferOp(Operation &op, const int hardware_generation) {
} else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
TypedValue<MemRefType> arg = alloca_op.getResult();
const MemRefType memref_ty = alloca_op.getResult().getType();
FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation));
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
alloca_op.getResult().setType(new_memref_ty);
if (memref_ty != new_memref_ty) {
OpBuilder builder(alloca_op->getContext());
@@ -240,7 +268,7 @@ LogicalResult inferOp(Operation &op, const int hardware_generation) {
for (Region &region : op.getRegions()) {
for (Block &block : region) {
for (Operation& op : block) {
if (failed(inferOp(op, hardware_generation))) {
if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
return failure();
}
}
@@ -249,7 +277,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation) {
return success();
}
LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
const TpuTilingFlags &tpu_tiling_flags) {
if (!f.getBody().hasOneBlock()) {
return f.emitOpError("Functions should only have a single block");
}
@@ -273,7 +302,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, leading_tile_rows));
inferMemref(memref_ty, hardware_generation, tpu_tiling_flags,
leading_tile_rows));
arg.setType(new_memref_ty);
new_arg_types.push_back(arg.getType());
if (memref_ty != new_memref_ty) {
@@ -292,7 +322,7 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
f.setFunctionType(
builder.getAttr<FunctionType>(new_arg_types, f.getResultTypes()));
for (Operation &op : entry.getOperations()) {
if (failed(inferOp(op, hardware_generation))) {
if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
return failure();
}
}
@@ -307,14 +337,15 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation) {
// Arguments:
// module: The MLIR module on which to perform the inference.
// hardware_generation: The TPU hardware generation to target.
// tpu_tiling_flags: Flags indicating which large tiling modes XLA has enabled.
LogicalResult inferModule(ModuleOp module, const int hardware_generation) {
LogicalResult inferModule(ModuleOp module, const int hardware_generation,
const TpuTilingFlags &tpu_tiling_flags) {
// TODO(apaszke): Do layout assignment for scoped allocations too.
for (Operation &op : *module.getBody()) {
auto f = dyn_cast<func::FuncOp>(op);
if (f == nullptr) {
return module.emitOpError("Expected only FuncOps but found ") << op;
}
if (failed(inferFunc(f, hardware_generation))) {
if (failed(inferFunc(f, hardware_generation, tpu_tiling_flags))) {
return failure();
}
}
@@ -323,8 +354,10 @@ LogicalResult inferModule(ModuleOp module, const int hardware_generation) {
struct InferMemRefLayoutPass
: public impl::InferMemRefLayoutPassBase<InferMemRefLayoutPass> {
InferMemRefLayoutPass(int hardware_generation_) {
InferMemRefLayoutPass(int hardware_generation_,
const TpuTilingFlags &tpu_tiling_flags_) {
hardware_generation = hardware_generation_;
tpu_tiling_flags = tpu_tiling_flags_;
}
void runOnOperation() override {
// Fail if hardware_generation has not been changed from its default value.
@@ -333,7 +366,7 @@ struct InferMemRefLayoutPass
return;
}
func::FuncOp func = getOperation();
if (failed(inferFunc(func, hardware_generation))) {
if (failed(inferFunc(func, hardware_generation, tpu_tiling_flags))) {
signalPassFailure();
return;
}
@@ -341,8 +374,9 @@ struct InferMemRefLayoutPass
};
std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
int hardware_generation) {
return std::make_unique<InferMemRefLayoutPass>(hardware_generation);
int hardware_generation, const TpuTilingFlags &tpu_tiling_flags_) {
return std::make_unique<InferMemRefLayoutPass>(hardware_generation,
tpu_tiling_flags_);
}
} // namespace mlir::tpu


@@ -5,10 +5,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir::tpu {
FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
const TpuTilingFlags& tpu_tiling_flags,
int64_t leading_tile_rows = 0);
const std::string_view kLeadingTileRows = "leading_tile_rows";


@@ -1805,6 +1805,20 @@ class VectorLayoutInferer {
}
// Fall through.
}
if (auto store = dyn_cast<vector::StoreOp>(operand.getOwner())) {
auto maybe_tiling = verifyMemoryTiling(
store, getMemRefLayout(store.getBase()).getTiles(),
store.getMemRefType().getRank(),
store.getMemRefType().getElementTypeBitWidth());
if (maybe_tiling) {
auto tiling = *maybe_tiling;
if (tiling ==
nativeTiling(store.getMemRefType().getElementTypeBitWidth())) {
continue;
}
}
// Fall through.
}
return false;
}
return true;