mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[Mosaic] Consider divisibility when doing large tiling
PiperOrigin-RevId: 728980108
This commit is contained in:
parent
262aab74f0
commit
37af0135b0
@ -37,50 +37,53 @@ namespace mlir::tpu {
|
||||
// Returns the number of lane groups (a group is usually 128 lanes) in a tile.
|
||||
//
|
||||
// Arguments:
|
||||
// num_lanes: A number of lanes in the full operand.
|
||||
// src_sublane: The number of sublanes in the full operand.
|
||||
// hardware_generation: An integer indicating the target TPU generation.
|
||||
// sublane_count: The number of sublanes.
|
||||
// tiling_sublane: The number of sublanes in the target shape.
|
||||
// tpu_tiling_flags: A struct of flags indicating which large tiling modes are
|
||||
// enabled by XLA for memrefs.
|
||||
// bitwidth: The bitwidth of the element type of the operand.
|
||||
// is_kernel_argument: Whether the operand is a kernel argument.
|
||||
int getTilingFactor(const int num_lanes, const int hardware_generation,
|
||||
const int64_t sublane_count,
|
||||
int getTilingFactor(const int src_sublane,
|
||||
const int hardware_generation,
|
||||
const int64_t tiling_sublane,
|
||||
const TpuTilingFlags &tpu_tiling_flags,
|
||||
const int8_t bitwidth,
|
||||
const bool is_kernel_argument) {
|
||||
const int8_t bitwidth, const bool is_kernel_argument) {
|
||||
CHECK(llvm::isPowerOf2_32(bitwidth));
|
||||
CHECK_LE(4, bitwidth);
|
||||
CHECK_LE(bitwidth, 32);
|
||||
const int packing = 32 / bitwidth;
|
||||
const int min_tiling = (1 + (hardware_generation < 4)) * packing;
|
||||
const int max_normal_tiling = sublane_count;
|
||||
const int max_normal_tiling = tiling_sublane;
|
||||
|
||||
const int large_tiling = [&] {
|
||||
int large_tiling = [&] {
|
||||
if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) {
|
||||
return sublane_count * 8;
|
||||
return tiling_sublane * 8;
|
||||
}
|
||||
if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
|
||||
return sublane_count * 4;
|
||||
return tiling_sublane * 4;
|
||||
}
|
||||
// 16-bit values are generally always possible to relayout on the fly in v6,
|
||||
// so we allow large 2nd minor tiling whenever possible. We can't do this
|
||||
// for kernel arguments, because the layout of those is controlled by XLA.
|
||||
if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor ||
|
||||
(!is_kernel_argument && hardware_generation >= 6))) {
|
||||
return sublane_count * 2;
|
||||
return tiling_sublane * 2;
|
||||
}
|
||||
return sublane_count;
|
||||
return tiling_sublane;
|
||||
}();
|
||||
|
||||
bool is_divisible = src_sublane % large_tiling == 0;
|
||||
large_tiling = is_divisible ? large_tiling : tiling_sublane;
|
||||
|
||||
// Use large tiling if our operand is tall enough to fit at least one full
|
||||
// tile.
|
||||
if (large_tiling <= num_lanes) {
|
||||
if (large_tiling <= src_sublane) {
|
||||
return large_tiling;
|
||||
}
|
||||
|
||||
int tiling = min_tiling;
|
||||
while (tiling < std::min(num_lanes, max_normal_tiling)) {
|
||||
while (tiling < std::min(src_sublane, max_normal_tiling)) {
|
||||
tiling *= 2;
|
||||
}
|
||||
return tiling;
|
||||
@ -123,11 +126,12 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
|
||||
const auto [sublane_count, lane_count] = target_shape;
|
||||
// Infer the layout
|
||||
if (memref_ty.getRank() == 1) {
|
||||
auto src_sublane =
|
||||
llvm::divideCeil(memref_ty.getShape().back(), lane_count);
|
||||
const int64_t leading_tile =
|
||||
getTilingFactor(
|
||||
llvm::divideCeil(memref_ty.getShape().back(), lane_count),
|
||||
hardware_generation, sublane_count, tpu_tiling_flags, bitwidth,
|
||||
is_kernel_argument) *
|
||||
getTilingFactor(src_sublane, hardware_generation,
|
||||
sublane_count, tpu_tiling_flags, bitwidth,
|
||||
is_kernel_argument) *
|
||||
lane_count;
|
||||
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
|
||||
if (bitwidth != 32) {
|
||||
@ -143,11 +147,12 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
|
||||
|
||||
// memref.getRank() > 1
|
||||
const ArrayRef<int64_t> shape = memref_ty.getShape();
|
||||
const int64_t second_minor = shape[shape.size() - 2];
|
||||
|
||||
const int64_t src_sublane = shape[shape.size() - 2];
|
||||
if (leading_tile_rows == 0) {
|
||||
leading_tile_rows =
|
||||
getTilingFactor(second_minor, hardware_generation, sublane_count,
|
||||
tpu_tiling_flags, bitwidth, is_kernel_argument);
|
||||
leading_tile_rows = getTilingFactor(
|
||||
src_sublane, hardware_generation, sublane_count,
|
||||
tpu_tiling_flags, bitwidth, is_kernel_argument);
|
||||
}
|
||||
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, lane_count})};
|
||||
if (bitwidth != 32) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user