[Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.

PiperOrigin-RevId: 684392184
jax authors 2024-10-10 04:27:45 -07:00
parent 351187d9da
commit 81a95f78b9
7 changed files with 99 additions and 96 deletions


@@ -274,6 +274,7 @@ mlir.register_lowering(tpu_custom_call_p, _tpu_custom_call_lowering,
def _lower_tpu_kernel(
module: ir.Module,
hardware_generation: int,
target_shape: tuple[int, int],
) -> ir.Module:
"""Runs MLIR passes lowering the given module to an MLIR module.
@@ -283,6 +284,7 @@ def _lower_tpu_kernel(
Args:
module: The MLIR module to lower.
hardware_generation: The TPU hardware generation to target.
target_shape: The target shape of (sublane_count, lane_count).
Returns:
An MLIR module implementing the kernel.
@@ -312,11 +314,16 @@ def _lower_tpu_kernel(
pipeline.run(module.operation)
dump_mlir(module, "post-hlo-conversion")
sl_cnt, l_cnt = target_shape
# Note: we don't pass the TpuTilingFlags here, since we don't know the
# tiling decisions made by the compiler / what flags are enabled at this
# point, so we assume everything can be tiled up to default tiling.
pipeline = [
f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
"func.func(tpu-infer-memref-layout{"
f" hardware-generation={hardware_generation}"
f" sublane-count={sl_cnt}"
f" lane-count={l_cnt}"
"})"
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
@@ -357,14 +364,16 @@ def _lower_tpu_kernel(
dump_mlir(module, "post-canonicalize-mosaic")
pipeline = [
"func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
(
"func.func(tpu-infer-vector-layout{"
f" sublane-count={sl_cnt} lane-count={l_cnt}"
"})"
),
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
dump_mlir(module, "post-infer-vector-layout")
sl_cnt = 8
l_cnt = 128
mxu_size = 128 if hardware_generation < 6 else 256
pipeline = [
"func.func(tpu-apply-vector-layout{"
@@ -414,7 +423,10 @@ def _lower_mosaic_module_to_asm(
"tpu_custom_call cannot be lowered on a machine without TPUs "
"when mosaic_use_python_pipeline=True.")
hardware_generation = int(device_kind[len("TPU v")])
module = _lower_tpu_kernel(module, hardware_generation)
# TODO(b/369418606): Infer the target shape from the hardware generation.
module = _lower_tpu_kernel(
module, hardware_generation, target_shape=(8, 128)
)
needs_hlo_passes = False
needs_layout_passes = False
prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
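
As a point of reference, the pass-pipeline strings assembled above come out as follows. This is a minimal, stand-alone Python sketch assuming hardware_generation=5 and the default target_shape=(8, 128) passed by _lower_mosaic_module_to_asm; it mirrors the string construction in this diff rather than output captured from a real run.

hardware_generation = 5   # assumption for illustration
sl_cnt, l_cnt = (8, 128)  # default target_shape from the caller above

memref_pass = (
    "func.func(tpu-infer-memref-layout{"
    f" hardware-generation={hardware_generation}"
    f" sublane-count={sl_cnt}"
    f" lane-count={l_cnt}"
    "})"
)
vector_pass = (
    "func.func(tpu-infer-vector-layout{"
    f" sublane-count={sl_cnt} lane-count={l_cnt}"
    "})"
)
print(memref_pass)  # func.func(tpu-infer-memref-layout{ hardware-generation=5 sublane-count=8 lane-count=128})
print(vector_pass)  # func.func(tpu-infer-vector-layout{ sublane-count=8 lane-count=128})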


@@ -38,18 +38,8 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
let mnemonic = mnemonic_;
}
def TPU_Vreg : Type<
And<[IsVectorTypePred,
Or<[
And<[
CPred<"llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
CPred<"llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32">
]>,
CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{"
"8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">,
]>
]>,
"native-sized vreg", "::mlir::VectorType">;
// TODO(b/369418606): Find out the way to verify vreg size.
def TPU_Vreg : Type<IsVectorTypePred, "native-sized vreg", "::mlir::VectorType">;
class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
: TypeDef<TPU_Dialect, name, traits> {
@@ -738,6 +728,8 @@ def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncO
// If hardware_generation is not set, the default value of -1 will crash on
// runOnOperation.
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">,
];
}


@@ -68,13 +68,15 @@ struct ApplyVectorLayoutContext {
std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {});
int hardware_generation = -1,
std::array<int64_t, 2> target_shape = {8, 128},
const TpuTilingFlags &tpu_tiling_flags = {});
std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
int hardware_generation = -1);
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int lane_count = 128, int sublane_count = 8);
std::array<int64_t, 2> target_shape = {8, 128});
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});


@@ -164,7 +164,7 @@ FailureOr<TypedValue<MemRefType>> getInternalScratch(
FAILUREOR_ASSIGN_OR_RETURN(
MemRefType scratch_ref_ty,
inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation,
/*tpu_tiling_flags=*/{}, sublane_tiling));
ctx.target_shape, /*tpu_tiling_flags=*/{}, sublane_tiling));
return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
.getResult();
}
@@ -490,7 +490,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
MemRefType arg_type,
inferMemref(
MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
ctx.hardware_generation, /*tpu_tiling_flags=*/{}));
ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}));
const BlockArgument argument = entry_block.insertArgument(
entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx));
const FunctionType func_ty = func.getFunctionType();
@@ -5821,8 +5821,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
if (try_replicate_rows && packing == 1 &&
*(vregs.dimensions().end() - 2) == 1 &&
src.offsets() == LayoutOffsets{0, 0} &&
src.tiling() == std::array<int64_t, 2>{1, 128} &&
dst_tiling == std::array<int64_t, 2>{8, 128}) {
src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
dst_tiling == ctx.target_shape) {
xla::Array<Value> retiled(dst_tiles_shape);
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
@@ -5839,9 +5839,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
return std::pair(dst, std::move(retiled));
}
// (8,128) -> (8 * packing,128) tiling change for packed type.
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{8, 128} &&
dst_tiling == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * dst.packing(),
ctx.target_shape[1]}) {
// Note: for int4, retiling with scratch is always faster.
if (bitwidth != 4 || !has_enough_scratch) {
xla::Array<Value> retiled(dst_tiles_shape);
@@ -5883,8 +5883,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
// match corresponding elements without shifting. It's just that
// the tiles are not adjacent (no contiguous vreg slice).
if (bitwidth < 32 && 32 % bitwidth == 0 &&
src_tiling == std::array<int64_t, 2>{1, 128 * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, 128}) {
src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
// To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
// 4 sublanes and 2 lanes (this is convenient to keep the example small
// yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
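
To make the generalization in these hunks concrete, here is a minimal Python sketch of the three retiling conditions once they are expressed against target_shape instead of hard-coded (8, 128) tuples. The function names and tuple arguments are illustrative stand-ins for the C++ values (src.tiling(), dst_tiling, ctx.target_shape, dst.packing()), not part of the real API.

def replicate_rows_retile(src_tiling, dst_tiling, target_shape):
    # Was hard-coded as (1, 128) -> (8, 128); now derived from target_shape.
    return src_tiling == (1, target_shape[1]) and dst_tiling == target_shape

def pack_second_minor_retile(src_tiling, dst_tiling, target_shape, packing):
    # Was (8, 128) -> (8 * packing, 128) for packed types.
    return (src_tiling == target_shape
            and dst_tiling == (target_shape[0] * packing, target_shape[1]))

def interleave_retile(src_tiling, dst_tiling, target_shape, packing):
    # Was (1, 128 * packing) -> (packing, 128).
    return (src_tiling == (1, target_shape[1] * packing)
            and dst_tiling == (packing, target_shape[1]))

# With the default (8, 128) target shape these reduce to the old conditions:
assert replicate_rows_retile((1, 128), (8, 128), (8, 128))
assert pack_second_minor_retile((8, 128), (16, 128), (8, 128), packing=2)
assert interleave_retile((1, 256), (2, 128), (8, 128), packing=2)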


@@ -1,6 +1,7 @@
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <memory>
@@ -33,16 +34,17 @@ namespace mlir::tpu {
#define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
// Returns the number of 128-element groups in a tile.
// Returns the number of lane_count-element (usually 128-element) groups in a tile.
//
// Arguments:
// num_128s: A number of 128-element groups in the full operand.
// num_lanes: A number of lanes in the full operand.
// hardware_generation: An integer indicating the target TPU generation.
// sublane_count: The number of sublanes.
// tpu_tiling_flags: A struct of flags indicating which large tiling modes are
// enabled by XLA for memrefs.
// bitwidth: The bitwidth of the element type of the operand.
int getTilingFactor(const int num_128s, const int hardware_generation,
int getTilingFactor(const int num_lanes, const int hardware_generation,
const int64_t sublane_count,
const TpuTilingFlags &tpu_tiling_flags,
const int8_t bitwidth) {
CHECK(llvm::isPowerOf2_32(bitwidth));
@@ -50,29 +52,29 @@ int getTilingFactor(const int num_128s, const int hardware_generation,
CHECK_LE(bitwidth, 32);
const int packing = 32 / bitwidth;
const int min_tiling = (1 + (hardware_generation < 4)) * packing;
const int max_normal_tiling = 8;
const int max_normal_tiling = sublane_count;
const int large_tiling = [&] {
if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) {
return 64;
return sublane_count * 8;
}
if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
return 32;
return sublane_count * 4;
}
if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
return 16;
return sublane_count * 2;
}
return 8;
return sublane_count;
}();
// Use large tiling if our operand is tall enough to fit at least one full
// tile.
if (large_tiling <= num_128s) {
if (large_tiling <= num_lanes) {
return large_tiling;
}
int tiling = min_tiling;
while (tiling < std::min(num_128s, max_normal_tiling)) {
while (tiling < std::min(num_lanes, max_normal_tiling)) {
tiling *= 2;
}
return tiling;
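
To illustrate how sublane_count now threads through the tiling-factor selection, here is a small stand-alone Python re-implementation of the logic above. It is an illustrative sketch, not the library code; the boolean large_second_minor is a hypothetical stand-in for the per-bitwidth TpuTilingFlags fields.

def tiling_factor(num_lane_groups, hardware_generation, sublane_count,
                  bitwidth, large_second_minor=False):
    # Sketch of getTilingFactor after this change.
    assert bitwidth in (4, 8, 16, 32)
    packing = 32 // bitwidth
    min_tiling = (1 + (hardware_generation < 4)) * packing
    max_normal_tiling = sublane_count
    # sublane_count * packing reproduces the old constants 64 / 32 / 16 when
    # sublane_count == 8 and the matching large-tiling flag is enabled.
    large_tiling = sublane_count * packing if large_second_minor else sublane_count
    # Use large tiling only if the operand is tall enough for one full tile.
    if large_tiling <= num_lane_groups:
        return large_tiling
    tiling = min_tiling
    while tiling < min(num_lane_groups, max_normal_tiling):
        tiling *= 2
    return tiling

# With the default 8 sublanes this matches the previously hard-coded behavior:
assert tiling_factor(100, hardware_generation=5, sublane_count=8, bitwidth=32) == 8
assert tiling_factor(3, hardware_generation=5, sublane_count=8, bitwidth=32) == 4
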
@@ -80,6 +82,7 @@ int getTilingFactor(const int num_128s, const int hardware_generation,
FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
const int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags,
int64_t leading_tile_rows = 0) {
if (auto tiled_layout_attr =
@@ -100,12 +103,14 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
"Invalid element type for memref");
}
const int8_t bitwidth = memref_ty.getElementTypeBitWidth();
const auto [sublane_count, lane_count] = target_shape;
// Infer the layout
if (memref_ty.getRank() == 1) {
const int64_t leading_tile =
getTilingFactor(llvm::divideCeil(memref_ty.getShape().back(), 128),
hardware_generation, tpu_tiling_flags, bitwidth) *
128;
getTilingFactor(
llvm::divideCeil(memref_ty.getShape().back(), lane_count),
hardware_generation, sublane_count, tpu_tiling_flags, bitwidth) *
lane_count;
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
if (bitwidth != 32) {
if (!llvm::has_single_bit<unsigned>(bitwidth) || bitwidth > 32) {
@@ -113,7 +118,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
"Unsupported bitwidth: ")
<< bitwidth;
}
tiles.append({xla::Tile({128}), xla::Tile({32 / bitwidth, 1})});
tiles.append({xla::Tile({lane_count}), xla::Tile({32 / bitwidth, 1})});
}
return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1});
}
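
A worked example of the rank-1 branch above, with illustrative numbers (assumes no large-tiling flags and hardware generation >= 4):

# Hypothetical int8 memref of shape (1000,) with target_shape = (8, 128).
lane_count, sublane_count, bitwidth = 128, 8, 8
num_lane_groups = -(-1000 // lane_count)   # ceil(1000 / 128) == 8
# getTilingFactor(8, ...) returns 8 here, so the leading 1-D tile is:
leading_tile = 8 * lane_count              # 1024 elements
# Inferred tiles: (1024), (128), (4, 1); the trailing (4, 1) packs 32/8 = 4 rows
# of int8 into each 32-bit word.
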
@@ -122,10 +127,11 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
const ArrayRef<int64_t> shape = memref_ty.getShape();
const int64_t second_minor = shape[shape.size() - 2];
if (leading_tile_rows == 0) {
leading_tile_rows = getTilingFactor(second_minor, hardware_generation,
tpu_tiling_flags, bitwidth);
leading_tile_rows =
getTilingFactor(second_minor, hardware_generation, sublane_count,
tpu_tiling_flags, bitwidth);
}
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, 128})};
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, lane_count})};
if (bitwidth != 32) {
if (!llvm::has_single_bit<unsigned>(bitwidth) || bitwidth > 32) {
return emitError(UnknownLoc::get(memref_ty.getContext()),
@@ -134,7 +140,8 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
}
tiles.push_back(xla::Tile({32 / bitwidth, 1}));
}
auto tile_strides = ComputeTileStrides(memref_ty, {leading_tile_rows, 128});
auto tile_strides =
ComputeTileStrides(memref_ty, {leading_tile_rows, lane_count});
return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides);
}
return emitError(UnknownLoc::get(memref_ty.getContext()),
@@ -167,6 +174,7 @@ LogicalResult checkTiles(MLIRContext *mlir_ctx,
FailureOr<MemRefType> inferMemref(MemRefType memref,
const int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags,
int64_t leading_tile_rows) {
if (isa<SemaphoreType, DMASemaphoreType>(memref.getElementType())) {
@@ -188,9 +196,10 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem);
const Attribute memory_space =
memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace();
FAILUREOR_ASSIGN_OR_RETURN(const TiledLayoutAttr layout,
inferLayout(memref, hardware_generation,
tpu_tiling_flags, leading_tile_rows));
FAILUREOR_ASSIGN_OR_RETURN(
const TiledLayoutAttr layout,
inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags,
leading_tile_rows));
const ArrayRef<xla::Tile> tiles = layout.getTiles();
if (failed(checkTiles(memref.getContext(), tiles))) {
@@ -212,13 +221,14 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
}
LogicalResult inferOp(Operation &op, const int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags) {
if (auto alloca_op = dyn_cast<memref::AllocaOp>(op)) {
TypedValue<MemRefType> arg = alloca_op.getResult();
const MemRefType memref_ty = alloca_op.getResult().getType();
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation,
target_shape, tpu_tiling_flags));
alloca_op.getResult().setType(new_memref_ty);
if (memref_ty != new_memref_ty) {
OpBuilder builder(alloca_op->getContext());
@@ -233,9 +243,9 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
} else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
TypedValue<MemRefType> arg = alloca_op.getResult();
const MemRefType memref_ty = alloca_op.getResult().getType();
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation,
target_shape, tpu_tiling_flags));
alloca_op.getResult().setType(new_memref_ty);
if (memref_ty != new_memref_ty) {
OpBuilder builder(alloca_op->getContext());
@@ -251,7 +261,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
for (Region &region : op.getRegions()) {
for (Block &block : region) {
for (Operation& op : block) {
if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
if (failed(inferOp(op, hardware_generation, target_shape,
tpu_tiling_flags))) {
return failure();
}
}
@@ -261,6 +272,7 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
}
LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags) {
if (!f.getBody().hasOneBlock()) {
return f.emitOpError("Functions should only have a single block");
@@ -285,8 +297,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, tpu_tiling_flags,
leading_tile_rows));
inferMemref(memref_ty, hardware_generation, target_shape,
tpu_tiling_flags, leading_tile_rows));
arg.setType(new_memref_ty);
new_arg_types.push_back(arg.getType());
if (memref_ty != new_memref_ty) {
@@ -305,30 +317,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
f.setFunctionType(
builder.getAttr<FunctionType>(new_arg_types, f.getResultTypes()));
for (Operation &op : entry.getOperations()) {
if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
return failure();
}
}
return success();
}
// Infers the layout and memory space attributes of function memref arguments.
//
// In the future we should require those annotations from Mosaic users, but it's
// best to keep them internal for as long as they are under development.
//
// Arguments:
// module: The MLIR module on which to perform the inference.
// hardware_generation: The TPU hardware generation to target.
LogicalResult inferModule(ModuleOp module, const int hardware_generation,
const TpuTilingFlags &tpu_tiling_flags) {
// TODO(apaszke): Do layout assignment for scoped allocations too.
for (Operation &op : *module.getBody()) {
auto f = dyn_cast<func::FuncOp>(op);
if (f == nullptr) {
return module.emitOpError("Expected only FuncOps but found ") << op;
}
if (failed(inferFunc(f, hardware_generation, tpu_tiling_flags))) {
if (failed(
inferOp(op, hardware_generation, target_shape, tpu_tiling_flags))) {
return failure();
}
}
@@ -338,8 +328,11 @@ LogicalResult inferModule(ModuleOp module, const int hardware_generation,
struct InferMemRefLayoutPass
: public impl::InferMemRefLayoutPassBase<InferMemRefLayoutPass> {
InferMemRefLayoutPass(int hardware_generation_,
std::array<int64_t, 2> target_shape_,
const TpuTilingFlags &tpu_tiling_flags_) {
hardware_generation = hardware_generation_;
sublane_count = target_shape_[0];
lane_count = target_shape_[1];
tpu_tiling_flags = tpu_tiling_flags_;
}
void runOnOperation() override {
@@ -349,7 +342,8 @@ struct InferMemRefLayoutPass
return;
}
func::FuncOp func = getOperation();
if (failed(inferFunc(func, hardware_generation, tpu_tiling_flags))) {
if (failed(inferFunc(func, hardware_generation, {sublane_count, lane_count},
tpu_tiling_flags))) {
signalPassFailure();
return;
}
@@ -357,9 +351,10 @@ struct InferMemRefLayoutPass
};
std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
int hardware_generation, const TpuTilingFlags &tpu_tiling_flags_) {
return std::make_unique<InferMemRefLayoutPass>(hardware_generation,
tpu_tiling_flags_);
int hardware_generation, std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags_) {
return std::make_unique<InferMemRefLayoutPass>(
hardware_generation, target_shape, tpu_tiling_flags_);
}
} // namespace mlir::tpu


@@ -1,7 +1,9 @@
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_
#include <string>
#include <array>
#include <cstdint>
#include <string_view>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
@@ -10,6 +12,7 @@
namespace mlir::tpu {
FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags& tpu_tiling_flags,
int64_t leading_tile_rows = 0);


@@ -21,9 +21,7 @@ limitations under the License.
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVectorExtras.h"
@@ -34,7 +32,6 @@ limitations under the License.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
@@ -42,6 +39,7 @@ limitations under the License.
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
@@ -49,6 +47,7 @@ limitations under the License.
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OpDefinition.h"
#include "mlir/include/mlir/IR/Visitors.h"
#include "mlir/include/mlir/Pass/Pass.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "xla/layout.h"
@@ -2043,9 +2042,9 @@ class VectorLayoutInferer {
struct InferVectorLayoutPass
: public impl::InferVectorLayoutPassBase<InferVectorLayoutPass> {
InferVectorLayoutPass(int lane_count, int sublane_count) {
this->sublane_count = sublane_count;
this->lane_count = lane_count;
InferVectorLayoutPass(std::array<int64_t, 2> target_shape) {
this->sublane_count = target_shape[0];
this->lane_count = target_shape[1];
}
void runOnOperation() override {
func::FuncOp func = getOperation();
@@ -2059,8 +2058,8 @@ struct InferVectorLayoutPass
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int lane_count, int sublane_count) {
return std::make_unique<InferVectorLayoutPass>(lane_count, sublane_count);
std::array<int64_t, 2> target_shape) {
return std::make_unique<InferVectorLayoutPass>(target_shape);
}
} // namespace mlir::tpu