[Mosaic TPU] Add support for 1D windows

PiperOrigin-RevId: 657976726
This commit is contained in:
Adam Paszke 2024-07-31 05:57:39 -07:00 committed by jax authors
parent 4c13594bdd
commit 9dba6eb16a
3 changed files with 8 additions and 7 deletions

View File

@ -323,7 +323,7 @@ def _lower_tpu_kernel(
)
pipeline = [
"func.func(tpu-canonicalize-mosaic{})",
f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation}}})",
]
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
pipeline.run(module.operation)

View File

@ -711,7 +711,9 @@ def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::Func
"::mlir::tpu::TPUDialect",
];
let constructor = "::mlir::tpu::createCanonicalizeMosaicPass()";
let options = [];
let options = [
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
];
}
def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> {

View File

@ -334,14 +334,13 @@ class MosaicCanonicalizer {
struct CanonicalizeMosaicPass
: public impl::CanonicalizeMosaicPassBase<CanonicalizeMosaicPass> {
CanonicalizeMosaicPass(int hardware_generation)
: hardware_generation_(hardware_generation) {}
int hardware_generation_;
CanonicalizeMosaicPass(int hardware_generation) {
this->hardware_generation = hardware_generation;
}
void runOnOperation() override {
func::FuncOp func = getOperation();
MosaicCanonicalizer vlc(hardware_generation_);
MosaicCanonicalizer vlc(hardware_generation);
if (vlc.canonicalize(func).failed()) {
signalPassFailure();
}