mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[Mosaic TPU] Add support for 1D windows
PiperOrigin-RevId: 657976726
This commit is contained in:
parent
4c13594bdd
commit
9dba6eb16a
@ -323,7 +323,7 @@ def _lower_tpu_kernel(
|
||||
)
|
||||
|
||||
pipeline = [
|
||||
"func.func(tpu-canonicalize-mosaic{})",
|
||||
f"func.func(tpu-canonicalize-mosaic{{hardware-generation={hardware_generation}}})",
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
|
@ -711,7 +711,9 @@ def CanonicalizeMosaicPass : Pass<"tpu-canonicalize-mosaic", "::mlir::func::Func
|
||||
"::mlir::tpu::TPUDialect",
|
||||
];
|
||||
let constructor = "::mlir::tpu::createCanonicalizeMosaicPass()";
|
||||
let options = [];
|
||||
let options = [
|
||||
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
|
||||
];
|
||||
}
|
||||
|
||||
def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> {
|
||||
|
@ -334,14 +334,13 @@ class MosaicCanonicalizer {
|
||||
|
||||
struct CanonicalizeMosaicPass
|
||||
: public impl::CanonicalizeMosaicPassBase<CanonicalizeMosaicPass> {
|
||||
CanonicalizeMosaicPass(int hardware_generation)
|
||||
: hardware_generation_(hardware_generation) {}
|
||||
|
||||
int hardware_generation_;
|
||||
CanonicalizeMosaicPass(int hardware_generation) {
|
||||
this->hardware_generation = hardware_generation;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
func::FuncOp func = getOperation();
|
||||
MosaicCanonicalizer vlc(hardware_generation_);
|
||||
MosaicCanonicalizer vlc(hardware_generation);
|
||||
if (vlc.canonicalize(func).failed()) {
|
||||
signalPassFailure();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user