[Mosaic][TPU] Add a compatibility mode to Mosaic's canonicalization pass, skipping over elementwise and matmul op insertions and/or type compat casts.

PiperOrigin-RevId: 714132282
This commit is contained in:
jax authors 2025-01-10 12:10:13 -08:00
parent 743872dfed
commit a16fbffc13
4 changed files with 68 additions and 31 deletions

View File

@@ -367,8 +367,12 @@ def _lower_tpu_kernel(
f"Unrecognized on-device check categories: {', '.join(checks)}"
)
# Legacy pipeline always runs in compatibility mode.
compatibility_mode = True
pipeline = [
f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation}}})",
(
f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation} compatibility-mode={compatibility_mode}}})"
),
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)

View File

@@ -833,6 +833,7 @@ def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::Func
let constructor = "::mlir::tpu::createCanonicalizeMosaicPass()";
let options = [
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"compatibility_mode", "compatibility-mode", "bool", /*default=*/"1", "">,
];
}

View File

@@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
const TpuTilingFlags &tpu_tiling_flags = {});
std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
int hardware_generation = -1);
int hardware_generation = -1, bool compatibility_mode = true);
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int hardware_generation = -1,

View File

@@ -49,7 +49,15 @@ namespace mlir::tpu {
namespace {
LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
struct CanonicalizeContext {
// see Note: Compatibility mode
bool compatibility_mode;
int hardware_generation;
};
LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx,
tpu::MatmulOp op) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto transpose_lhs = op.getTransposeLhs();
@@ -134,6 +142,11 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
};
if (lhs_element_type != rhs_element_type) {
if (!ctx.compatibility_mode) {
return op->emitOpError(
"Mosaic matmul invoked with mixed element types, but compatibility "
"mode is disabled.");
}
if (lhs_element_type.isInteger() && rhs_element_type.isInteger()) {
// TODO(mvoz): Add support for mixed int/int matmul.
op->emitOpError("Mix int/int - NYI");
@@ -264,7 +277,7 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
return success();
};
LogicalResult canonicalize_elementwise(int hardware_generation_,
LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx,
Operation &op) {
OpBuilder builder(&op);
auto operands = op.getOperands();
@@ -295,21 +308,28 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
return failure();
}
auto element_type = ty.getElementType();
// PowFOp and DivFOp do not seem to be supported in bf16 on later
// hardware.
// There's an annoying hodgepodge of elementwise ops that need to be
// rewritten to f32 on later hardware.
// TODO(mvoz): Look into (1) what it would take to support these ops
// natively on later hardware, and (2) how to better organize this list.
bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
bool needs_cast = ctx.hardware_generation <= 5 || isa<math::PowFOp>(op) ||
isa<math::TanhOp>(op) || isa<math::ExpOp>(op) ||
isa<math::LogOp>(op);
if (needs_cast && element_type.isBF16()) {
auto target_f32 =
builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
.getResult();
should_rewrite_op = true;
new_operands.push_back(target_f32);
if (ctx.compatibility_mode) {
auto target_f32 =
builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
.getResult();
should_rewrite_op = true;
new_operands.push_back(target_f32);
} else {
op.emitOpError(
"Compatibility mode disabled. Unsupported element type in "
"elementwise op on hardware generation: ")
<< ctx.hardware_generation
<< ". Use hardware generation after 5 or cast to f32.";
return failure();
}
} else {
new_operands.push_back(operand);
}
@@ -341,7 +361,7 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
return success();
}
LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
LogicalResult canonicalize_multi_dim_reduction(const CanonicalizeContext &ctx,
Operation &operation) {
ImplicitLocOpBuilder builder(operation.getLoc(), &operation);
auto op = cast<vector::MultiDimReductionOp>(operation);
@@ -361,7 +381,7 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
reduces_sublanes = true;
}
}
if (hardware_generation <= 5) {
if (ctx.hardware_generation <= 5) {
auto new_source = builder.create<arith::ExtFOp>(
VectorType::get(source_ty.getShape(), builder.getF32Type()),
op.getSource());
@@ -401,16 +421,18 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
return failure();
}
LogicalResult canonicalize_matmul(int hardware_generation, Operation &op) {
LogicalResult canonicalize_matmul(const CanonicalizeContext &ctx,
Operation &op) {
auto matmul_op = dyn_cast<tpu::MatmulOp>(op);
if (!matmul_op) {
op.emitOpError("Invariant violated: Not a matmul");
return failure();
}
return tpu_matmul_rule(matmul_op);
return tpu_matmul_rule(ctx, matmul_op);
};
LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) {
LogicalResult canonicalize_contraction(const CanonicalizeContext &ctx,
Operation &op) {
auto contraction_op = dyn_cast<vector::ContractionOp>(op);
if (!contraction_op) {
op.emitOpError("Invariant violated: Not a contraction");
@@ -478,11 +500,12 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) {
/*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr);
contraction_op.replaceAllUsesWith(matmul_op.getResult());
contraction_op.erase();
auto result = tpu_matmul_rule(matmul_op);
auto result = tpu_matmul_rule(ctx, matmul_op);
return result;
}
LogicalResult canonicalize_extract(int hardware_generation, Operation &raw_op) {
LogicalResult canonicalize_extract(const CanonicalizeContext &ctx,
Operation &raw_op) {
auto op = dyn_cast<vector::ExtractOp>(raw_op);
Type result_ty = op.getResult().getType();
if (!isa<VectorType>(result_ty)) {
@@ -497,7 +520,8 @@ LogicalResult canonicalize_extract(int hardware_generation, Operation &raw_op) {
return success();
}
LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) {
LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
Operation &raw_op) {
auto op = dyn_cast<arith::SelectOp>(raw_op);
if (!isa<VectorType>(op.getType()) ||
isa<VectorType>(op.getCondition().getType())) {
@@ -515,7 +539,8 @@ LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) {
return success();
}
LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) {
LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
Operation &raw_op) {
auto op = dyn_cast<tpu::RepeatOp>(raw_op);
if (!isa<VectorType>(op.getType())) {
return op.emitOpError("Only vector types supported");
@@ -539,7 +564,7 @@ LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) {
}
using canonicalize_rule_type =
std::function<LogicalResult(int hardware_generation, Operation &op)>;
std::function<LogicalResult(const CanonicalizeContext &ctx, Operation &op)>;
const llvm::StringMap<canonicalize_rule_type> &rules() {
static auto rules = new llvm::StringMap<canonicalize_rule_type>{
@@ -569,10 +594,12 @@ const llvm::StringSet<> &elementwise_convertible_ops() {
class MosaicCanonicalizer {
public:
MosaicCanonicalizer(int hardware_generation)
: hardware_generation_(hardware_generation) {}
MosaicCanonicalizer(int hardware_generation, bool compatibility_mode)
: hardware_generation_(hardware_generation),
compatibility_mode_(compatibility_mode) {}
int hardware_generation_;
bool compatibility_mode_;
LogicalResult canonicalize(func::FuncOp op) {
if (!op.getBody().hasOneBlock()) {
@@ -593,6 +620,7 @@ class MosaicCanonicalizer {
}
LogicalResult canonicalizeOp(Operation &any_op) {
CanonicalizeContext ctx({compatibility_mode_, hardware_generation_});
// We must iterate over the op first, because canonicalization can cause
// us to .erase() an op, and accessing getRegions on it after is not sound.
// Invariant - top level ops with regions may never be invalidated.
@@ -605,12 +633,12 @@ class MosaicCanonicalizer {
}
if (elementwise_convertible_ops().contains(
any_op.getName().getStringRef())) {
return canonicalize_elementwise(hardware_generation_, any_op);
return canonicalize_elementwise(ctx, any_op);
}
if (auto rule_it = rules().find(any_op.getName().getStringRef());
rule_it != rules().end()) {
const canonicalize_rule_type &rule = rule_it->getValue();
return rule(hardware_generation_, any_op);
return rule(ctx, any_op);
}
return success();
}
@@ -618,24 +646,28 @@
struct CanonicalizeMosaicPass
: public impl::CanonicalizeMosaicPassBase<CanonicalizeMosaicPass> {
CanonicalizeMosaicPass(int hardware_generation) {
this->hardware_generation = hardware_generation;
CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p)
: compatibility_mode_(compatibility_mode_p) {
this->hardware_generation = hardware_generation_p;
}
void runOnOperation() override {
func::FuncOp func = getOperation();
MosaicCanonicalizer vlc(hardware_generation);
MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_);
if (vlc.canonicalize(func).failed()) {
signalPassFailure();
}
};
bool compatibility_mode_;
};
} // namespace
std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
int hardware_generation) {
return std::make_unique<CanonicalizeMosaicPass>(hardware_generation);
int hardware_generation, bool compatibility_mode) {
return std::make_unique<CanonicalizeMosaicPass>(hardware_generation,
compatibility_mode);
}
} // namespace mlir::tpu