Mirror of https://github.com/ROCm/jax.git
[Mosaic] Parameterize the number of lanes and sublanes in TPU dialects.

PiperOrigin-RevId: 684392184

parent: 351187d9da
commit: 81a95f78b9
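
In brief: the Python lowering entry point now takes an explicit (sublane_count, lane_count) pair and threads it through every layout pass. A minimal sketch of the new call, assuming a TPU v5 target (the value 5 and the `module` variable are illustrative stand-ins; the signature is taken from the diff below):

    # Sketch only: signature from the diff below; hardware_generation=5 and
    # `module` are illustrative, not part of the commit.
    lowered = _lower_tpu_kernel(
        module,                 # ir.Module holding the Mosaic kernel
        hardware_generation=5,  # parsed from the device kind, e.g. "TPU v5"
        target_shape=(8, 128),  # (sublane_count, lane_count) of one vreg
    )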
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -274,6 +274,7 @@ mlir.register_lowering(tpu_custom_call_p, _tpu_custom_call_lowering,
 def _lower_tpu_kernel(
     module: ir.Module,
     hardware_generation: int,
+    target_shape: tuple[int, int],
 ) -> ir.Module:
   """Runs MLIR passes lowering the given module to an MLIR module.
 
@@ -283,6 +284,7 @@ def _lower_tpu_kernel(
   Args:
     module: The MLIR module to lower.
     hardware_generation: The TPU hardware generation to target.
+    target_shape: The target shape of (sublane_count, lane_count).
 
   Returns:
     An MLIR module implementing the kernel.
@@ -312,11 +314,16 @@ def _lower_tpu_kernel(
   pipeline.run(module.operation)
   dump_mlir(module, "post-hlo-conversion")
 
+  sl_cnt, l_cnt = target_shape
   # Note: we don't pass the TpuTilingFlags here, since we don't know the
   # tiling decisions made by the compiler / what flags are enabled at this
   # point, so we assume everything can be tiled up to default tiling.
   pipeline = [
-      f"func.func(tpu-infer-memref-layout{{hardware-generation={hardware_generation}}})"
+      "func.func(tpu-infer-memref-layout{"
+      f" hardware-generation={hardware_generation}"
+      f" sublane-count={sl_cnt}"
+      f" lane-count={l_cnt}"
+      "})"
   ]
   pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
   pipeline.run(module.operation)
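
For concreteness, here is what the new option string evaluates to; MLIR's textual pipeline syntax treats the space-separated tokens inside `{...}` as individual pass options. A quick sketch with assumed values hardware_generation=5 and target_shape=(8, 128):

    sl_cnt, l_cnt = 8, 128   # assumed target_shape
    hardware_generation = 5  # assumed TPU generation
    s = ("func.func(tpu-infer-memref-layout{"
         f" hardware-generation={hardware_generation}"
         f" sublane-count={sl_cnt}"
         f" lane-count={l_cnt}"
         "})")
    assert s == ("func.func(tpu-infer-memref-layout{ hardware-generation=5"
                 " sublane-count=8 lane-count=128})")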
@@ -357,14 +364,16 @@ def _lower_tpu_kernel(
   dump_mlir(module, "post-canonicalize-mosaic")
 
   pipeline = [
-      "func.func(tpu-infer-vector-layout{sublane-count=8 lane-count=128})",
+      (
+          "func.func(tpu-infer-vector-layout{"
+          f" sublane-count={sl_cnt} lane-count={l_cnt}"
+          "})"
+      ),
   ]
   pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
   pipeline.run(module.operation)
   dump_mlir(module, "post-infer-vector-layout")
 
-  sl_cnt = 8
-  l_cnt = 128
   mxu_size = 128 if hardware_generation < 6 else 256
   pipeline = [
       "func.func(tpu-apply-vector-layout{"
@@ -414,7 +423,10 @@ def _lower_mosaic_module_to_asm(
         "tpu_custom_call cannot be lowered on a machine without TPUs "
         "when mosaic_use_python_pipeline=True.")
     hardware_generation = int(device_kind[len("TPU v")])
-    module = _lower_tpu_kernel(module, hardware_generation)
+    # TODO(b/369418606): Infer the target shape from the hardware generation.
+    module = _lower_tpu_kernel(
+        module, hardware_generation, target_shape=(8, 128)
+    )
   needs_hlo_passes = False
   needs_layout_passes = False
   prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
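
The call site pins target_shape=(8, 128) for now; the TODO is to derive it from the device. A hypothetical sketch of that derivation, not part of this commit — every TPU generation shipped so far uses an (8, 128) vreg, which is what makes the hardcoded tuple safe in the meantime:

    # Hypothetical helper (not in the commit): resolve the vreg shape per
    # generation, should shapes ever diverge.
    def _infer_target_shape(hardware_generation: int) -> tuple[int, int]:
        del hardware_generation  # unused while all generations share one shape
        return (8, 128)          # (sublane_count, lane_count)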
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -38,18 +38,8 @@ class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
   let mnemonic = mnemonic_;
 }
 
-def TPU_Vreg : Type<
-  And<[IsVectorTypePred,
-    Or<[
-      And<[
-        CPred<"llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
-        CPred<"llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32">
-      ]>,
-      CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{"
-            "8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">,
-    ]>
-  ]>,
-  "native-sized vreg", "::mlir::VectorType">;
+// TODO(b/369418606): Find out the way to verify vreg size.
+def TPU_Vreg : Type<IsVectorTypePred, "native-sized vreg", "::mlir::VectorType">;
 
 class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
     : TypeDef<TPU_Dialect, name, traits> {
@@ -738,6 +728,8 @@ def InferMemRefLayoutPass : Pass<"tpu-infer-memref-layout", "::mlir::func::FuncO
   // If hardware_generation is not set, the default value of -1 will crash on
   // runOnOperation.
   Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
+  Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
+  Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
   Option<"tpu_tiling_flags", "tpu-tiling-flags", "::mlir::tpu::TpuTilingFlags", /*default=*/"::mlir::tpu::TpuTilingFlags{}", "">,
   ];
 }
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -68,13 +68,15 @@ struct ApplyVectorLayoutContext {
 std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
-    int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {});
+    int hardware_generation = -1,
+    std::array<int64_t, 2> target_shape = {8, 128},
+    const TpuTilingFlags &tpu_tiling_flags = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
     int hardware_generation = -1);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
-    int lane_count = 128, int sublane_count = 8);
+    std::array<int64_t, 2> target_shape = {8, 128});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
     const ApplyVectorLayoutContext &ctx = ApplyVectorLayoutContext{});
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -164,7 +164,7 @@ FailureOr<TypedValue<MemRefType>> getInternalScratch(
   FAILUREOR_ASSIGN_OR_RETURN(
       MemRefType scratch_ref_ty,
       inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation,
-                  /*tpu_tiling_flags=*/{}, sublane_tiling));
+                  ctx.target_shape, /*tpu_tiling_flags=*/{}, sublane_tiling));
   return builder.create<tpu::GetInternalScratchOp>(loc, scratch_ref_ty)
       .getResult();
 }
@@ -490,7 +490,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
       MemRefType arg_type,
       inferMemref(
           MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
-          ctx.hardware_generation, /*tpu_tiling_flags=*/{}));
+          ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}));
   const BlockArgument argument = entry_block.insertArgument(
       entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx));
   const FunctionType func_ty = func.getFunctionType();
@@ -5821,8 +5821,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
   if (try_replicate_rows && packing == 1 &&
       *(vregs.dimensions().end() - 2) == 1 &&
       src.offsets() == LayoutOffsets{0, 0} &&
-      src.tiling() == std::array<int64_t, 2>{1, 128} &&
-      dst_tiling == std::array<int64_t, 2>{8, 128}) {
+      src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
+      dst_tiling == ctx.target_shape) {
     xla::Array<Value> retiled(dst_tiles_shape);
     retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
       SmallVector<int64_t> src_idx(idx.begin(), idx.end());
@@ -5839,9 +5839,9 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
     return std::pair(dst, std::move(retiled));
   }
   // (8,128) -> (8 * packing,128) tiling change for packed type.
-  if (bitwidth < 32 && 32 % bitwidth == 0 &&
-      src_tiling == std::array<int64_t, 2>{8, 128} &&
-      dst_tiling == std::array<int64_t, 2>{8 * dst.packing(), 128}) {
+  if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
+      dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * dst.packing(),
+                                           ctx.target_shape[1]}) {
     // Note: for int4, retiling with scratch is always faster.
     if (bitwidth != 4 || !has_enough_scratch) {
       xla::Array<Value> retiled(dst_tiles_shape);
@@ -5883,8 +5883,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
   // match corresponding elements without shifting. It's just that
   // the tiles are not adjacent (no contiguous vreg slice).
   if (bitwidth < 32 && 32 % bitwidth == 0 &&
-      src_tiling == std::array<int64_t, 2>{1, 128 * packing} &&
-      dst_tiling == std::array<int64_t, 2>{packing, 128}) {
+      src_tiling == std::array<int64_t, 2>{1, ctx.target_shape[1] * packing} &&
+      dst_tiling == std::array<int64_t, 2>{packing, ctx.target_shape[1]}) {
     // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of
     // 4 sublanes and 2 lanes (this is convenient to keep the example small
     // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling.
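
The retiling fast paths in changeTiling now compare tilings against ctx.target_shape instead of the literals (1, 128), (8, 128), and (packing, 128). A Python paraphrase of the three guards, with target_shape, packing, and bitwidth mirroring the C++ locals (dst_packing stands in for dst.packing()):

    def replicate_rows_guard(src_tiling, dst_tiling, target_shape, packing):
        # (1, lane_count) -> target_shape, single-row replication case
        return (packing == 1 and src_tiling == (1, target_shape[1])
                and dst_tiling == target_shape)

    def packed_retile_guard(src_tiling, dst_tiling, target_shape, bitwidth,
                            dst_packing):
        # target_shape -> (sublane_count * packing, lane_count), packed types
        return (bitwidth < 32 and 32 % bitwidth == 0
                and src_tiling == target_shape
                and dst_tiling == (target_shape[0] * dst_packing,
                                   target_shape[1]))

    def single_row_packed_guard(src_tiling, dst_tiling, target_shape, bitwidth,
                                packing):
        # (1, lane_count * packing) -> (packing, lane_count)
        return (bitwidth < 32 and 32 % bitwidth == 0
                and src_tiling == (1, target_shape[1] * packing)
                and dst_tiling == (packing, target_shape[1]))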
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -1,6 +1,7 @@
 #include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
 
 #include <algorithm>
+#include <array>
 #include <cstdint>
 #include <cstdlib>
 #include <memory>
@@ -33,16 +34,17 @@ namespace mlir::tpu {
 #define GEN_PASS_DEF_INFERMEMREFLAYOUTPASS
 #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
 
 
-// Returns the number of 128-element groups in a tile.
+// Returns the number of lanes (usually 128) groups in a tile.
 //
 // Arguments:
-//  num_128s: A number of 128-element groups in the full operand.
+//  num_lanes: A number of lanes in the full operand.
 //  hardware_generation: An integer indicating the target TPU generation.
+//  sublane_count: The number of sublanes.
 //  tpu_tiling_flags: A struct of flags indicating which large tiling modes are
 //    enabled by XLA for memrefs.
 //  bitwidth: The bitwidth of the element type of the operand.
-int getTilingFactor(const int num_128s, const int hardware_generation,
+int getTilingFactor(const int num_lanes, const int hardware_generation,
+                    const int64_t sublane_count,
                     const TpuTilingFlags &tpu_tiling_flags,
                     const int8_t bitwidth) {
   CHECK(llvm::isPowerOf2_32(bitwidth));
@@ -50,29 +52,29 @@ int getTilingFactor(const int num_128s, const int hardware_generation,
   CHECK_LE(bitwidth, 32);
   const int packing = 32 / bitwidth;
   const int min_tiling = (1 + (hardware_generation < 4)) * packing;
-  const int max_normal_tiling = 8;
+  const int max_normal_tiling = sublane_count;
 
   const int large_tiling = [&] {
     if (bitwidth == 4 && tpu_tiling_flags.use_x4_large_second_minor) {
-      return 64;
+      return sublane_count * 8;
     }
     if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
-      return 32;
+      return sublane_count * 4;
     }
     if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
-      return 16;
+      return sublane_count * 2;
    }
-    return 8;
+    return sublane_count;
   }();
 
   // Use large tiling if our operand is tall enough to fit at least one full
   // tile.
-  if (large_tiling <= num_128s) {
+  if (large_tiling <= num_lanes) {
     return large_tiling;
   }
 
   int tiling = min_tiling;
-  while (tiling < std::min(num_128s, max_normal_tiling)) {
+  while (tiling < std::min(num_lanes, max_normal_tiling)) {
     tiling *= 2;
   }
   return tiling;
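
The literals retired here are exactly sublane_count * packing evaluated at the default sublane_count of 8 (64 for 4-bit, 32 for 8-bit, 16 for 16-bit). A Python transcription of the parameterized function — a sketch, with the TpuTilingFlags struct reduced to a per-bitwidth boolean dict:

    def get_tiling_factor(num_lanes, hardware_generation, sublane_count,
                          use_large_second_minor, bitwidth):
        # use_large_second_minor: {4: bool, 8: bool, 16: bool}, standing in
        # for the use_x4/x8/x16_large_second_minor flags.
        assert bitwidth in (4, 8, 16, 32)
        packing = 32 // bitwidth
        min_tiling = (1 + (hardware_generation < 4)) * packing
        max_normal_tiling = sublane_count  # previously the literal 8
        if bitwidth < 32 and use_large_second_minor.get(bitwidth, False):
            large_tiling = sublane_count * packing  # was 64 / 32 / 16
        else:
            large_tiling = sublane_count            # was 8
        # Use large tiling if the operand is tall enough for one full tile.
        if large_tiling <= num_lanes:
            return large_tiling
        tiling = min_tiling
        while tiling < min(num_lanes, max_normal_tiling):
            tiling *= 2
        return tiling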
@@ -80,6 +82,7 @@ int getTilingFactor(const int num_128s, const int hardware_generation,
 
 FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                                        const int hardware_generation,
+                                       std::array<int64_t, 2> target_shape,
                                        const TpuTilingFlags &tpu_tiling_flags,
                                        int64_t leading_tile_rows = 0) {
   if (auto tiled_layout_attr =
@@ -100,12 +103,14 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                      "Invalid element type for memref");
   }
   const int8_t bitwidth = memref_ty.getElementTypeBitWidth();
+  const auto [sublane_count, lane_count] = target_shape;
   // Infer the layout
   if (memref_ty.getRank() == 1) {
     const int64_t leading_tile =
-        getTilingFactor(llvm::divideCeil(memref_ty.getShape().back(), 128),
-                        hardware_generation, tpu_tiling_flags, bitwidth) *
-        128;
+        getTilingFactor(
+            llvm::divideCeil(memref_ty.getShape().back(), lane_count),
+            hardware_generation, sublane_count, tpu_tiling_flags, bitwidth) *
+        lane_count;
     SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
     if (bitwidth != 32) {
       if (!llvm::has_single_bit<unsigned>(bitwidth) || bitwidth > 32) {
@@ -113,7 +118,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                          "Unsupported bitwidth: ")
             << bitwidth;
       }
-      tiles.append({xla::Tile({128}), xla::Tile({32 / bitwidth, 1})});
+      tiles.append({xla::Tile({lane_count}), xla::Tile({32 / bitwidth, 1})});
     }
     return TiledLayoutAttr::get(memref_ty.getContext(), tiles, {1});
   }
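
A worked example of the rank-1 branch under the sketch above, assuming a (1000,) float32 memref, hardware_generation=5, and target_shape=(8, 128): ceil(1000 / 128) = 8 lane-sized groups, the tiling factor comes out to 8, and the leading tile is 8 * 128 = 1024 elements.

    import math

    sublane_count, lane_count = 8, 128  # assumed target_shape
    leading_tile = get_tiling_factor(
        math.ceil(1000 / lane_count),   # 8 lane-sized groups
        hardware_generation=5, sublane_count=sublane_count,
        use_large_second_minor={}, bitwidth=32,
    ) * lane_count
    assert leading_tile == 1024         # layout tile: xla::Tile({1024})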
@@ -122,10 +127,11 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
   const ArrayRef<int64_t> shape = memref_ty.getShape();
   const int64_t second_minor = shape[shape.size() - 2];
   if (leading_tile_rows == 0) {
-    leading_tile_rows = getTilingFactor(second_minor, hardware_generation,
-                                        tpu_tiling_flags, bitwidth);
+    leading_tile_rows =
+        getTilingFactor(second_minor, hardware_generation, sublane_count,
+                        tpu_tiling_flags, bitwidth);
   }
-  SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, 128})};
+  SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, lane_count})};
   if (bitwidth != 32) {
     if (!llvm::has_single_bit<unsigned>(bitwidth) || bitwidth > 32) {
       return emitError(UnknownLoc::get(memref_ty.getContext()),
@@ -134,7 +140,8 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
     }
     tiles.push_back(xla::Tile({32 / bitwidth, 1}));
   }
-  auto tile_strides = ComputeTileStrides(memref_ty, {leading_tile_rows, 128});
+  auto tile_strides =
+      ComputeTileStrides(memref_ty, {leading_tile_rows, lane_count});
   return TiledLayoutAttr::get(memref_ty.getContext(), tiles, tile_strides);
 }
 return emitError(UnknownLoc::get(memref_ty.getContext()),
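
ComputeTileStrides is unchanged apart from now receiving lane_count; its assumed semantics (a sketch, not the C++ source) are row-major strides measured in units of tiles over the tiled shape:

    def compute_tile_strides(shape, tiling):
        # Tiled shape: untouched leading dims, ceil-divided last two dims.
        tiled = list(shape[:-2]) + [
            -(-shape[-2] // tiling[0]),  # second-minor / leading_tile_rows
            -(-shape[-1] // tiling[1]),  # minor / lane_count
        ]
        strides = [1] * len(tiled)
        for i in range(len(tiled) - 2, -1, -1):
            strides[i] = strides[i + 1] * tiled[i + 1]
        return strides

    # e.g. shape (4, 16, 256) with (8, 128) tiling -> tiled shape (4, 2, 2),
    # strides [4, 2, 1]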
@@ -167,6 +174,7 @@ LogicalResult checkTiles(MLIRContext *mlir_ctx,
 
 FailureOr<MemRefType> inferMemref(MemRefType memref,
                                   const int hardware_generation,
+                                  std::array<int64_t, 2> target_shape,
                                   const TpuTilingFlags &tpu_tiling_flags,
                                   int64_t leading_tile_rows) {
   if (isa<SemaphoreType, DMASemaphoreType>(memref.getElementType())) {
@@ -188,9 +196,10 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
       tpu::MemorySpaceAttr::get(memref.getContext(), MemorySpace::vmem);
   const Attribute memory_space =
       memref.getMemorySpace() == nullptr ? vmem : memref.getMemorySpace();
-  FAILUREOR_ASSIGN_OR_RETURN(const TiledLayoutAttr layout,
-                             inferLayout(memref, hardware_generation,
-                                         tpu_tiling_flags, leading_tile_rows));
+  FAILUREOR_ASSIGN_OR_RETURN(
+      const TiledLayoutAttr layout,
+      inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags,
+                  leading_tile_rows));
 
   const ArrayRef<xla::Tile> tiles = layout.getTiles();
   if (failed(checkTiles(memref.getContext(), tiles))) {
@@ -212,13 +221,14 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
 }
 
 LogicalResult inferOp(Operation &op, const int hardware_generation,
+                      std::array<int64_t, 2> target_shape,
                       const TpuTilingFlags &tpu_tiling_flags) {
   if (auto alloca_op = dyn_cast<memref::AllocaOp>(op)) {
     TypedValue<MemRefType> arg = alloca_op.getResult();
     const MemRefType memref_ty = alloca_op.getResult().getType();
-    FAILUREOR_ASSIGN_OR_RETURN(
-        const MemRefType new_memref_ty,
-        inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
+    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
+                               inferMemref(memref_ty, hardware_generation,
+                                           target_shape, tpu_tiling_flags));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -233,9 +243,9 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
   } else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
     TypedValue<MemRefType> arg = alloca_op.getResult();
     const MemRefType memref_ty = alloca_op.getResult().getType();
-    FAILUREOR_ASSIGN_OR_RETURN(
-        const MemRefType new_memref_ty,
-        inferMemref(memref_ty, hardware_generation, tpu_tiling_flags));
+    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
+                               inferMemref(memref_ty, hardware_generation,
+                                           target_shape, tpu_tiling_flags));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -251,7 +261,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
   for (Region &region : op.getRegions()) {
     for (Block &block : region) {
       for (Operation& op : block) {
-        if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
+        if (failed(inferOp(op, hardware_generation, target_shape,
+                           tpu_tiling_flags))) {
           return failure();
         }
       }
@@ -261,6 +272,7 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
 }
 
 LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
+                        std::array<int64_t, 2> target_shape,
                         const TpuTilingFlags &tpu_tiling_flags) {
   if (!f.getBody().hasOneBlock()) {
     return f.emitOpError("Functions should only have a single block");
@@ -285,8 +297,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
 
     FAILUREOR_ASSIGN_OR_RETURN(
        const MemRefType new_memref_ty,
-        inferMemref(memref_ty, hardware_generation, tpu_tiling_flags,
-                    leading_tile_rows));
+        inferMemref(memref_ty, hardware_generation, target_shape,
+                    tpu_tiling_flags, leading_tile_rows));
     arg.setType(new_memref_ty);
     new_arg_types.push_back(arg.getType());
     if (memref_ty != new_memref_ty) {
@@ -305,30 +317,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
   f.setFunctionType(
       builder.getAttr<FunctionType>(new_arg_types, f.getResultTypes()));
   for (Operation &op : entry.getOperations()) {
-    if (failed(inferOp(op, hardware_generation, tpu_tiling_flags))) {
-      return failure();
-    }
-  }
-  return success();
-}
-
-// Infers the layout and memory space attributes of function memref arguments.
-//
-// In the future we should require those annotations from Mosaic users, but it's
-// best to keep them internal for as long as they are under development.
-//
-// Arguments:
-//  module: The MLIR module on which to perform the inference.
-//  hardware_generation: The TPU hardware generation to target.
-LogicalResult inferModule(ModuleOp module, const int hardware_generation,
-                          const TpuTilingFlags &tpu_tiling_flags) {
-  // TODO(apaszke): Do layout assignment for scoped allocations too.
-  for (Operation &op : *module.getBody()) {
-    auto f = dyn_cast<func::FuncOp>(op);
-    if (f == nullptr) {
-      return module.emitOpError("Expected only FuncOps but found ") << op;
-    }
-    if (failed(inferFunc(f, hardware_generation, tpu_tiling_flags))) {
+    if (failed(
+            inferOp(op, hardware_generation, target_shape, tpu_tiling_flags))) {
       return failure();
     }
   }
@@ -338,8 +328,11 @@ LogicalResult inferModule(ModuleOp module, const int hardware_generation,
 struct InferMemRefLayoutPass
     : public impl::InferMemRefLayoutPassBase<InferMemRefLayoutPass> {
   InferMemRefLayoutPass(int hardware_generation_,
+                        std::array<int64_t, 2> target_shape_,
                         const TpuTilingFlags &tpu_tiling_flags_) {
     hardware_generation = hardware_generation_;
+    sublane_count = target_shape_[0];
+    lane_count = target_shape_[1];
     tpu_tiling_flags = tpu_tiling_flags_;
   }
   void runOnOperation() override {
@@ -349,7 +342,8 @@ struct InferMemRefLayoutPass
       return;
     }
     func::FuncOp func = getOperation();
-    if (failed(inferFunc(func, hardware_generation, tpu_tiling_flags))) {
+    if (failed(inferFunc(func, hardware_generation, {sublane_count, lane_count},
+                         tpu_tiling_flags))) {
       signalPassFailure();
       return;
     }
@@ -357,9 +351,10 @@ struct InferMemRefLayoutPass
 };
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
-    int hardware_generation, const TpuTilingFlags &tpu_tiling_flags_) {
-  return std::make_unique<InferMemRefLayoutPass>(hardware_generation,
-                                                 tpu_tiling_flags_);
+    int hardware_generation, std::array<int64_t, 2> target_shape,
+    const TpuTilingFlags &tpu_tiling_flags_) {
+  return std::make_unique<InferMemRefLayoutPass>(
+      hardware_generation, target_shape, tpu_tiling_flags_);
 }
 
 }  // namespace mlir::tpu
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
@@ -1,7 +1,9 @@
 #ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_
 #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_INFER_MEMREF_LAYOUT_H_
 
-#include <string>
+#include <array>
+#include <cstdint>
+#include <string_view>
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LogicalResult.h"
@@ -10,6 +12,7 @@
 namespace mlir::tpu {
 
 FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
+                                  std::array<int64_t, 2> target_shape,
                                   const TpuTilingFlags& tpu_tiling_flags,
                                   int64_t leading_tile_rows = 0);
 
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVectorExtras.h"
|
||||
@ -34,7 +32,6 @@ limitations under the License.
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
@ -42,6 +39,7 @@ limitations under the License.
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/include/mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
@ -49,6 +47,7 @@ limitations under the License.
|
||||
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
|
||||
#include "mlir/include/mlir/IR/OpDefinition.h"
|
||||
#include "mlir/include/mlir/IR/Visitors.h"
|
||||
#include "mlir/include/mlir/Pass/Pass.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/layout.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
#include "xla/layout.h"
|
||||
@ -2043,9 +2042,9 @@ class VectorLayoutInferer {
|
||||
|
||||
struct InferVectorLayoutPass
|
||||
: public impl::InferVectorLayoutPassBase<InferVectorLayoutPass> {
|
||||
InferVectorLayoutPass(int lane_count, int sublane_count) {
|
||||
this->sublane_count = sublane_count;
|
||||
this->lane_count = lane_count;
|
||||
InferVectorLayoutPass(std::array<int64_t, 2> target_shape) {
|
||||
this->sublane_count = target_shape[0];
|
||||
this->lane_count = target_shape[1];
|
||||
}
|
||||
void runOnOperation() override {
|
||||
func::FuncOp func = getOperation();
|
||||
@ -2059,8 +2058,8 @@ struct InferVectorLayoutPass
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
|
||||
int lane_count, int sublane_count) {
|
||||
return std::make_unique<InferVectorLayoutPass>(lane_count, sublane_count);
|
||||
std::array<int64_t, 2> target_shape) {
|
||||
return std::make_unique<InferVectorLayoutPass>(target_shape);
|
||||
}
|
||||
|
||||
} // namespace mlir::tpu
|
||||
|