[Mosaic] Consider divisibility when doing large tiling

PiperOrigin-RevId: 728980108
This commit is contained in:
jax authors 2025-02-19 23:55:26 -08:00
parent 262aab74f0
commit 37af0135b0

View File

@ -37,50 +37,53 @@ namespace mlir::tpu {
// Returns the number of lane groups (a group is usually 128 lanes) in a tile.
//
// Arguments:
// num_lanes: A number of lanes in the full operand.
// src_sublane: The number of sublanes in the full operand.
// hardware_generation: An integer indicating the target TPU generation.
// sublane_count: The number of sublanes.
// tiling_sublane: The number of sublanes in the target shape.
// tpu_tiling_flags: A struct of flags indicating which large tiling modes are
// enabled by XLA for memrefs.
// bitwidth: The bitwidth of the element type of the operand.
// is_kernel_argument: Whether the operand is a kernel argument.
int getTilingFactor(const int num_lanes, const int hardware_generation,
const int64_t sublane_count,
int getTilingFactor(const int src_sublane,
const int hardware_generation,
const int64_t tiling_sublane,
const TpuTilingFlags &tpu_tiling_flags,
const int8_t bitwidth,
const bool is_kernel_argument) {
const int8_t bitwidth, const bool is_kernel_argument) {
CHECK(llvm::isPowerOf2_32(bitwidth));
CHECK_LE(4, bitwidth);
CHECK_LE(bitwidth, 32);
const int packing = 32 / bitwidth;
const int min_tiling = (1 + (hardware_generation < 4)) * packing;
const int max_normal_tiling = sublane_count;
const int max_normal_tiling = tiling_sublane;
const int large_tiling = [&] {
int large_tiling = [&] {
if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) {
return sublane_count * 8;
return tiling_sublane * 8;
}
if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
return sublane_count * 4;
return tiling_sublane * 4;
}
// 16-bit values are generally always possible to relayout on the fly in v6,
// so we allow large 2nd minor tiling whenever possible. We can't do this
// for kernel arguments, because the layout of those is controlled by XLA.
if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor ||
(!is_kernel_argument && hardware_generation >= 6))) {
return sublane_count * 2;
return tiling_sublane * 2;
}
return sublane_count;
return tiling_sublane;
}();
bool is_divisible = src_sublane % large_tiling == 0;
large_tiling = is_divisible ? large_tiling : tiling_sublane;
// Use large tiling if our operand is tall enough to fit at least one full
// tile.
if (large_tiling <= num_lanes) {
if (large_tiling <= src_sublane) {
return large_tiling;
}
int tiling = min_tiling;
while (tiling < std::min(num_lanes, max_normal_tiling)) {
while (tiling < std::min(src_sublane, max_normal_tiling)) {
tiling *= 2;
}
return tiling;
@ -123,11 +126,12 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
const auto [sublane_count, lane_count] = target_shape;
// Infer the layout
if (memref_ty.getRank() == 1) {
auto src_sublane =
llvm::divideCeil(memref_ty.getShape().back(), lane_count);
const int64_t leading_tile =
getTilingFactor(
llvm::divideCeil(memref_ty.getShape().back(), lane_count),
hardware_generation, sublane_count, tpu_tiling_flags, bitwidth,
is_kernel_argument) *
getTilingFactor(src_sublane, hardware_generation,
sublane_count, tpu_tiling_flags, bitwidth,
is_kernel_argument) *
lane_count;
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
if (bitwidth != 32) {
@ -143,11 +147,12 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
// memref.getRank() > 1
const ArrayRef<int64_t> shape = memref_ty.getShape();
const int64_t second_minor = shape[shape.size() - 2];
const int64_t src_sublane = shape[shape.size() - 2];
if (leading_tile_rows == 0) {
leading_tile_rows =
getTilingFactor(second_minor, hardware_generation, sublane_count,
tpu_tiling_flags, bitwidth, is_kernel_argument);
leading_tile_rows = getTilingFactor(
src_sublane, hardware_generation, sublane_count,
tpu_tiling_flags, bitwidth, is_kernel_argument);
}
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, lane_count})};
if (bitwidth != 32) {