Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 12:26:07 +00:00
[Mosaic][TPU] Add a compatibility mode to Mosaic's canonicalization pass, skipping over elementwise and matmul op insertions and/or type compat casts.
PiperOrigin-RevId: 714132282
This commit is contained in:
parent 743872dfed
commit a16fbffc13
@@ -367,8 +367,12 @@ def _lower_tpu_kernel(
         f"Unrecognized on-device check categories: {', '.join(checks)}"
     )
 
+  # Legacy pipeline always runs in compatibility mode.
+  compatibility_mode = True
   pipeline = [
-      f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation}}})",
+      (
+          f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation} compatibility-mode={compatibility_mode}}})"
+      ),
   ]
   pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
   pipeline.run(module.operation)
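For concreteness, with hardware_generation=5 the legacy path above now produces a pass-pipeline string like the following (illustrative values, not from the commit; Python formats the bool as "True", a spelling LLVM's bool option parser accepts):

  builtin.module(func.func(tpu-canonicalize-mosaic{hardware-generation=5 compatibility-mode=True}))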
@@ -833,6 +833,7 @@ def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::FuncOp"> {
   let constructor = "::mlir::tpu::createCanonicalizeMosaicPass()";
   let options = [
     Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
+    Option<"compatibility_mode", "compatibility-mode", "bool", /*default=*/"1", "">,
   ];
 }
 
@@ -76,7 +76,7 @@ std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
     const TpuTilingFlags &tpu_tiling_flags = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
-    int hardware_generation = -1);
+    int hardware_generation = -1, bool compatibility_mode = true);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
     int hardware_generation = -1,
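A minimal sketch of driving the updated factory from C++. The PassManager setup, the function wrapper, and the hardware-generation value are assumptions for illustration, not part of this commit:

  #include "mlir/Dialect/Func/IR/FuncOps.h"
  #include "mlir/IR/BuiltinOps.h"
  #include "mlir/Pass/PassManager.h"
  // ... plus the Mosaic TPU pass header that declares createCanonicalizeMosaicPass.

  mlir::LogicalResult runMosaicCanonicalizer(mlir::ModuleOp module) {
    mlir::PassManager pm(module.getContext());
    // Opt out of compatibility mode so unsupported mixed-type matmuls and
    // bf16 elementwise ops fail loudly instead of being rewritten.
    pm.addNestedPass<mlir::func::FuncOp>(
        mlir::tpu::createCanonicalizeMosaicPass(
            /*hardware_generation=*/5, /*compatibility_mode=*/false));
    return pm.run(module);
  }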
@@ -49,7 +49,15 @@ namespace mlir::tpu {
 
 namespace {
 
-LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
+struct CanonicalizeContext {
+  // see Note: Compatibility mode
+  bool compatibility_mode;
+
+  int hardware_generation;
+};
+
+LogicalResult tpu_matmul_rule(const CanonicalizeContext &ctx,
+                              tpu::MatmulOp op) {
   ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
 
   auto transpose_lhs = op.getTransposeLhs();
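A note on the new struct: it is a plain aggregate with no constructor, so callers brace-initialize it and the argument order must follow the member declaration order (compatibility_mode first, then hardware_generation), as the pass does later in this diff. A one-line sketch with illustrative values:

  // Order must match the member declaration order of CanonicalizeContext.
  CanonicalizeContext ctx{/*compatibility_mode=*/true, /*hardware_generation=*/5};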
@@ -134,6 +142,11 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
   };
 
   if (lhs_element_type != rhs_element_type) {
+    if (!ctx.compatibility_mode) {
+      return op->emitOpError(
+          "Mosaic matmul invoked with mixed element types, but compatibility "
+          "mode is disabled.");
+    }
     if (lhs_element_type.isInteger() && rhs_element_type.isInteger()) {
       // TODO(mvoz): Add support for mixed int/int matmul.
       op->emitOpError("Mix int/int - NYI");
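When compatibility mode stays on, control falls through to the pre-existing reconciliation logic (outside this hunk) that inserts type-compat casts. A hedged sketch of the float case; the type and operand names here are illustrative assumptions:

  // Sketch, assuming a bf16 lhs against an f32 rhs: extend the narrower side
  // so the element types agree before emitting the matmul.
  auto wide_ty =
      mlir::VectorType::get(lhs_vec_ty.getShape(), builder.getF32Type());
  mlir::Value lhs_f32 =
      builder.create<mlir::arith::ExtFOp>(op.getLoc(), wide_ty, op.getLhs());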
@@ -264,7 +277,7 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
   return success();
 };
 
-LogicalResult canonicalize_elementwise(int hardware_generation_,
+LogicalResult canonicalize_elementwise(const CanonicalizeContext &ctx,
                                        Operation &op) {
   OpBuilder builder(&op);
   auto operands = op.getOperands();
@@ -295,21 +308,28 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
       return failure();
     }
     auto element_type = ty.getElementType();
-    // PowFOp and DivFOp do not seem to be supported in bf16 on later
-    // hardware.
+    // There's an annoying hodgepodge of elementwise ops that need to be
+    // rewritten to f32 on later hardware.
     // TODO(mvoz): Look into (1) what it would take to support these ops
     // natively on later hardware, and (2) how to better organize this list.
-    bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
+    bool needs_cast = ctx.hardware_generation <= 5 || isa<math::PowFOp>(op) ||
                       isa<math::TanhOp>(op) || isa<math::ExpOp>(op) ||
                       isa<math::LogOp>(op);
     if (needs_cast && element_type.isBF16()) {
-      auto target_f32 =
-          builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
-              .getResult();
-      should_rewrite_op = true;
-      new_operands.push_back(target_f32);
+      if (ctx.compatibility_mode) {
+        auto target_f32 =
+            builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
+                .getResult();
+        should_rewrite_op = true;
+        new_operands.push_back(target_f32);
+      } else {
+        op.emitOpError(
+            "Compatibility mode disabled. Unsupported element type in "
+            "elementwise op on hardware generation: ")
+            << ctx.hardware_generation
+            << ". Use hardware generation after 5 or cast to f32.";
+        return failure();
+      }
     } else {
       new_operands.push_back(operand);
     }
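The compatibility branch above only widens the operands; the enclosing rewrite (outside this hunk) recreates the op at f32 and truncates the result back for the original users. A condensed sketch of the overall cast-compute-cast-back pattern, with assumed names and an assumed helper:

  // 1. bf16 -> f32 on each operand (the code in the hunk above).
  mlir::Value ext =
      builder.create<mlir::arith::ExtFOp>(loc, f32_vec_ty, operand);
  // 2. Re-create the elementwise op with f32 operands and result.
  //    rebuildElementwiseAtF32 is a hypothetical helper, not from this diff.
  mlir::Operation *f32_op = rebuildElementwiseAtF32(builder, op, new_operands);
  // 3. f32 -> bf16 so downstream users keep seeing the original type.
  mlir::Value back = builder.create<mlir::arith::TruncFOp>(
      loc, bf16_vec_ty, f32_op->getResult(0));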
@@ -341,7 +361,7 @@ LogicalResult canonicalize_elementwise(int hardware_generation_,
   return success();
 }
 
-LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
+LogicalResult canonicalize_multi_dim_reduction(const CanonicalizeContext &ctx,
                                                Operation &operation) {
   ImplicitLocOpBuilder builder(operation.getLoc(), &operation);
   auto op = cast<vector::MultiDimReductionOp>(operation);
@@ -361,7 +381,7 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
       reduces_sublanes = true;
     }
   }
-  if (hardware_generation <= 5) {
+  if (ctx.hardware_generation <= 5) {
     auto new_source = builder.create<arith::ExtFOp>(
         VectorType::get(source_ty.getShape(), builder.getF32Type()),
         op.getSource());
@@ -401,16 +421,18 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
     return failure();
   }
 
-LogicalResult canonicalize_matmul(int hardware_generation, Operation &op) {
+LogicalResult canonicalize_matmul(const CanonicalizeContext &ctx,
+                                  Operation &op) {
   auto matmul_op = dyn_cast<tpu::MatmulOp>(op);
   if (!matmul_op) {
     op.emitOpError("Invariant violated: Not a matmul");
     return failure();
   }
-  return tpu_matmul_rule(matmul_op);
+  return tpu_matmul_rule(ctx, matmul_op);
 };
 
-LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) {
+LogicalResult canonicalize_contraction(const CanonicalizeContext &ctx,
+                                       Operation &op) {
   auto contraction_op = dyn_cast<vector::ContractionOp>(op);
   if (!contraction_op) {
     op.emitOpError("Invariant violated: Not a contraction");
@@ -478,11 +500,12 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) {
       /*transpose_rhs=*/false, precision_attr, dot_dimension_numbers_attr);
   contraction_op.replaceAllUsesWith(matmul_op.getResult());
   contraction_op.erase();
-  auto result = tpu_matmul_rule(matmul_op);
+  auto result = tpu_matmul_rule(ctx, matmul_op);
   return result;
 }
 
-LogicalResult canonicalize_extract(int hardware_generation, Operation &raw_op) {
+LogicalResult canonicalize_extract(const CanonicalizeContext &ctx,
+                                   Operation &raw_op) {
   auto op = dyn_cast<vector::ExtractOp>(raw_op);
   Type result_ty = op.getResult().getType();
   if (!isa<VectorType>(result_ty)) {
@@ -497,7 +520,8 @@ LogicalResult canonicalize_extract(int hardware_generation, Operation &raw_op) {
   return success();
 }
 
-LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) {
+LogicalResult canonicalize_select(const CanonicalizeContext &ctx,
+                                  Operation &raw_op) {
   auto op = dyn_cast<arith::SelectOp>(raw_op);
   if (!isa<VectorType>(op.getType()) ||
       isa<VectorType>(op.getCondition().getType())) {
@@ -515,7 +539,8 @@ LogicalResult canonicalize_select(int hardware_generation, Operation &raw_op) {
   return success();
 }
 
-LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) {
+LogicalResult canonicalize_repeat(const CanonicalizeContext &ctx,
+                                  Operation &raw_op) {
   auto op = dyn_cast<tpu::RepeatOp>(raw_op);
   if (!isa<VectorType>(op.getType())) {
     return op.emitOpError("Only vector types supported");
@@ -539,7 +564,7 @@ LogicalResult canonicalize_repeat(int hardware_generation, Operation &raw_op) {
 }
 
 using canonicalize_rule_type =
-    std::function<LogicalResult(int hardware_generation, Operation &op)>;
+    std::function<LogicalResult(const CanonicalizeContext &ctx, Operation &op)>;
 
 const llvm::StringMap<canonicalize_rule_type> &rules() {
   static auto rules = new llvm::StringMap<canonicalize_rule_type>{
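With the signature change, every entry in rules() must now take the context. The concrete map contents are outside this hunk; a plausible sketch of the registration shape, with the specific op list assumed for illustration:

  static auto rules = new llvm::StringMap<canonicalize_rule_type>{
      {tpu::MatmulOp::getOperationName(), canonicalize_matmul},
      {vector::ContractionOp::getOperationName(), canonicalize_contraction},
      {vector::ExtractOp::getOperationName(), canonicalize_extract},
  };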
@@ -569,10 +594,12 @@ const llvm::StringSet<> &elementwise_convertible_ops() {
 
 class MosaicCanonicalizer {
  public:
-  MosaicCanonicalizer(int hardware_generation)
-      : hardware_generation_(hardware_generation) {}
+  MosaicCanonicalizer(int hardware_generation, bool compatibility_mode)
+      : hardware_generation_(hardware_generation),
+        compatibility_mode_(compatibility_mode) {}
 
   int hardware_generation_;
+  bool compatibility_mode_;
 
   LogicalResult canonicalize(func::FuncOp op) {
     if (!op.getBody().hasOneBlock()) {
@@ -593,6 +620,7 @@ class MosaicCanonicalizer {
   }
 
   LogicalResult canonicalizeOp(Operation &any_op) {
+    CanonicalizeContext ctx({compatibility_mode_, hardware_generation_});
     // We must iterate over the op first, because canonicalization can cause
     // us to .erase() an op, and accessing getRegions on it after is not sound.
     // Invariant - top level ops with regions may never be invalidated.
@@ -605,12 +633,12 @@ class MosaicCanonicalizer {
     }
     if (elementwise_convertible_ops().contains(
             any_op.getName().getStringRef())) {
-      return canonicalize_elementwise(hardware_generation_, any_op);
+      return canonicalize_elementwise(ctx, any_op);
     }
     if (auto rule_it = rules().find(any_op.getName().getStringRef());
         rule_it != rules().end()) {
       const canonicalize_rule_type &rule = rule_it->getValue();
-      return rule(hardware_generation_, any_op);
+      return rule(ctx, any_op);
     }
     return success();
   }
@@ -618,24 +646,28 @@ class MosaicCanonicalizer {
 
 struct CanonicalizeMosaicPass
     : public impl::CanonicalizeMosaicPassBase<CanonicalizeMosaicPass> {
-  CanonicalizeMosaicPass(int hardware_generation) {
-    this->hardware_generation = hardware_generation;
+  CanonicalizeMosaicPass(int hardware_generation_p, bool compatibility_mode_p)
+      : compatibility_mode_(compatibility_mode_p) {
+    this->hardware_generation = hardware_generation_p;
   }
 
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-    MosaicCanonicalizer vlc(hardware_generation);
+    MosaicCanonicalizer vlc(hardware_generation, compatibility_mode_);
    if (vlc.canonicalize(func).failed()) {
      signalPassFailure();
    }
   };
 
+  bool compatibility_mode_;
 };
 
 } // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
-    int hardware_generation) {
-  return std::make_unique<CanonicalizeMosaicPass>(hardware_generation);
+    int hardware_generation, bool compatibility_mode) {
+  return std::make_unique<CanonicalizeMosaicPass>(hardware_generation,
+                                                  compatibility_mode);
 }
 
 } // namespace mlir::tpu