[MOSAIC] apply_vector_layout C++ rewrite (2) No-op pass and flag to use it instead of Python

PiperOrigin-RevId: 561697585
This commit is contained in:
Tomás Longeri 2023-08-31 10:41:14 -07:00 committed by jax authors
parent faa7a68422
commit 24c3a9dc79
4 changed files with 73 additions and 1 deletions

View File

@ -38,6 +38,12 @@ from jaxlib.mlir.dialects import stablehlo
from jaxlib.mlir.passmanager import PassManager
import numpy as np
config.define_bool_state(
name="use_cpp_apply_vector_layout",
default=False,
help="Use C++ implementation of apply vector layout pass (still a WIP)",
)
# TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14.
if tpu_mosaic is None:
raise ImportError("Cannot use Mosaic without a jaxlib >= 0.4.14.")
@ -240,10 +246,18 @@ def _lower_tpu_kernel(
module.operation.verify()
dump_mlir(module, "after infer vector layout pass")
apply_vector_layout.apply(module, hardware_generation)
if config.use_cpp_apply_vector_layout:
pipeline = [
"func.func(tpu-apply-vector-layout)",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)
else:
apply_vector_layout.apply(module, hardware_generation)
module.operation.verify()
dump_mlir(module, "after apply vector layout pass")
PassManager.parse("builtin.module(canonicalize)").run(module.operation)
dump_mlir(module, "after final canonicalize pass")

View File

@ -425,6 +425,23 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncO
];
}
def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::arith::ArithDialect",
"::mlir::func::FuncDialect",
"::mlir::vector::VectorDialect",
"::mlir::tpu::TPUDialect",
];
let constructor = "::mlir::tpu::createApplyVectorLayoutPass(-1)";
let options = [
// If hardware_generation is not set, the default value of -1 will crash on
// runOnOperation.
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
];
}
def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp"> {
let dependentDialects = [
"::mlir::func::FuncDialect",

View File

@ -49,6 +49,9 @@ std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
int lane_count = 128, int sublane_count = 8);
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation, int lane_count = 128, int sublane_count = 8);
std::unique_ptr<OperationPass<func::FuncOp>>
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);

View File

@ -0,0 +1,38 @@
#include <memory>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h" // NOLINT
#include "mlir/Pass/Pass.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir::tpu {
#define GEN_PASS_DECL_APPLYVECTORLAYOUTPASS
#define GEN_PASS_DEF_APPLYVECTORLAYOUTPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
struct ApplyVectorLayoutPass
: public impl::ApplyVectorLayoutPassBase<ApplyVectorLayoutPass> {
ApplyVectorLayoutPass(int hardware_generation_, int lane_count_,
int sublane_count_) {
hardware_generation = hardware_generation_;
sublane_count = sublane_count_;
lane_count = lane_count_;
}
void runOnOperation() override {
// Fail if hardware_generation has not been set from the default value.
if (hardware_generation < 0) {
signalPassFailure();
return;
}
// No-op for now
}
};
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
int hardware_generation, int lane_count, int sublane_count) {
return std::make_unique<ApplyVectorLayoutPass>(hardware_generation,
lane_count, sublane_count);
}
} // namespace mlir::tpu