From 37af0135b0f63ed75e745e3415b69f5dfebd49a1 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 19 Feb 2025 23:55:26 -0800 Subject: [PATCH] [Mosaic] Consider divisibility when doing large tiling PiperOrigin-RevId: 728980108 --- .../tpu/transforms/infer_memref_layout.cc | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc index 05667a847..d63c13bde 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc @@ -37,50 +37,53 @@ namespace mlir::tpu { // Returns the number of lanes (usually 128) groups in a tile. // // Arguments: -// num_lanes: A number of lanes in the full operand. +// src_sublane: A number of lanes in the full operand. // hardware_generation: An integer indicating the target TPU generation. -// sublane_count: The number of sublanes. +// tiling_sublane: The number of sublane in the target shape. // tpu_tiling_flags: A struct of flags indicating which large tiling modes are // enabled by XLA for memrefs. // bitwidth: The bitwidth of the element type of the operand. // is_kernel_argument: Whether the operand is a kernel argument. -int getTilingFactor(const int num_lanes, const int hardware_generation, - const int64_t sublane_count, +int getTilingFactor(const int src_sublane, + const int hardware_generation, + const int64_t tiling_sublane, const TpuTilingFlags &tpu_tiling_flags, - const int8_t bitwidth, - const bool is_kernel_argument) { + const int8_t bitwidth, const bool is_kernel_argument) { CHECK(llvm::isPowerOf2_32(bitwidth)); CHECK_LE(4, bitwidth); CHECK_LE(bitwidth, 32); const int packing = 32 / bitwidth; const int min_tiling = (1 + (hardware_generation < 4)) * packing; - const int max_normal_tiling = sublane_count; + const int max_normal_tiling = tiling_sublane; - const int large_tiling = [&] { + int large_tiling = [&] { if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) { - return sublane_count * 8; + return tiling_sublane * 8; } if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) { - return sublane_count * 4; + return tiling_sublane * 4; } // 16-bit values are generally always possible to relayout on the fly in v6, // so we allow large 2nd minor tiling whenever possible. We can't do this // for kernel arguments, because the layout of those is controlled by XLA. if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor || (!is_kernel_argument && hardware_generation >= 6))) { - return sublane_count * 2; + return tiling_sublane * 2; } - return sublane_count; + return tiling_sublane; }(); + bool is_divisible = src_sublane % large_tiling == 0; + large_tiling = is_divisible ? large_tiling : tiling_sublane; + // Use large tiling if our operand is tall enough to fit at least one full // tile. - if (large_tiling <= num_lanes) { + if (large_tiling <= src_sublane) { return large_tiling; } int tiling = min_tiling; - while (tiling < std::min(num_lanes, max_normal_tiling)) { + while (tiling < std::min(src_sublane, max_normal_tiling)) { tiling *= 2; } return tiling; @@ -123,11 +126,12 @@ FailureOr inferLayout(MemRefType memref_ty, const auto [sublane_count, lane_count] = target_shape; // Infer the layout if (memref_ty.getRank() == 1) { + auto src_sublane = + llvm::divideCeil(memref_ty.getShape().back(), lane_count); const int64_t leading_tile = - getTilingFactor( - llvm::divideCeil(memref_ty.getShape().back(), lane_count), - hardware_generation, sublane_count, tpu_tiling_flags, bitwidth, - is_kernel_argument) * + getTilingFactor(src_sublane, hardware_generation, + sublane_count, tpu_tiling_flags, bitwidth, + is_kernel_argument) * lane_count; SmallVector tiles{xla::Tile({leading_tile})}; if (bitwidth != 32) { @@ -143,11 +147,12 @@ FailureOr inferLayout(MemRefType memref_ty, // memref.getRank() > 1 const ArrayRef shape = memref_ty.getShape(); - const int64_t second_minor = shape[shape.size() - 2]; + + const int64_t src_sublane = shape[shape.size() - 2]; if (leading_tile_rows == 0) { - leading_tile_rows = - getTilingFactor(second_minor, hardware_generation, sublane_count, - tpu_tiling_flags, bitwidth, is_kernel_argument); + leading_tile_rows = getTilingFactor( + src_sublane, hardware_generation, sublane_count, + tpu_tiling_flags, bitwidth, is_kernel_argument); } SmallVector tiles{xla::Tile({leading_tile_rows, lane_count})}; if (bitwidth != 32) {