mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[MOSAIC] apply_vector_layout C++ rewrite (2) No-op pass and flag to use it instead of Python
PiperOrigin-RevId: 561697585
This commit is contained in:
parent
faa7a68422
commit
24c3a9dc79
@ -38,6 +38,12 @@ from jaxlib.mlir.dialects import stablehlo
|
||||
from jaxlib.mlir.passmanager import PassManager
|
||||
import numpy as np
|
||||
|
||||
config.define_bool_state(
|
||||
name="use_cpp_apply_vector_layout",
|
||||
default=False,
|
||||
help="Use C++ implementation of apply vector layout pass (still a WIP)",
|
||||
)
|
||||
|
||||
# TODO(sharadmv): remove when minimum jaxlib version is bumped to >= 0.4.14.
|
||||
if tpu_mosaic is None:
|
||||
raise ImportError("Cannot use Mosaic without a jaxlib >= 0.4.14.")
|
||||
@ -240,10 +246,18 @@ def _lower_tpu_kernel(
|
||||
module.operation.verify()
|
||||
dump_mlir(module, "after infer vector layout pass")
|
||||
|
||||
apply_vector_layout.apply(module, hardware_generation)
|
||||
if config.use_cpp_apply_vector_layout:
|
||||
pipeline = [
|
||||
"func.func(tpu-apply-vector-layout)",
|
||||
]
|
||||
pipeline = PassManager.parse(f"builtin.module({','.join(pipeline)})")
|
||||
pipeline.run(module.operation)
|
||||
else:
|
||||
apply_vector_layout.apply(module, hardware_generation)
|
||||
module.operation.verify()
|
||||
dump_mlir(module, "after apply vector layout pass")
|
||||
|
||||
|
||||
PassManager.parse("builtin.module(canonicalize)").run(module.operation)
|
||||
dump_mlir(module, "after final canonicalize pass")
|
||||
|
||||
|
@ -425,6 +425,23 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncO
|
||||
];
|
||||
}
|
||||
|
||||
def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncOp"> {
|
||||
let dependentDialects = [
|
||||
"::mlir::arith::ArithDialect",
|
||||
"::mlir::func::FuncDialect",
|
||||
"::mlir::vector::VectorDialect",
|
||||
"::mlir::tpu::TPUDialect",
|
||||
];
|
||||
let constructor = "::mlir::tpu::createApplyVectorLayoutPass(-1)";
|
||||
let options = [
|
||||
// If hardware_generation is not set, the default value of -1 will crash on
|
||||
// runOnOperation.
|
||||
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
|
||||
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
|
||||
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
|
||||
];
|
||||
}
|
||||
|
||||
def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp"> {
|
||||
let dependentDialects = [
|
||||
"::mlir::func::FuncDialect",
|
||||
|
@ -49,6 +49,9 @@ std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
|
||||
int lane_count = 128, int sublane_count = 8);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
|
||||
int hardware_generation, int lane_count = 128, int sublane_count = 8);
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>>
|
||||
createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
|
||||
|
||||
|
38
jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Normal file
38
jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Normal file
@ -0,0 +1,38 @@
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h" // NOLINT
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
|
||||
|
||||
namespace mlir::tpu {
|
||||
|
||||
#define GEN_PASS_DECL_APPLYVECTORLAYOUTPASS
|
||||
#define GEN_PASS_DEF_APPLYVECTORLAYOUTPASS
|
||||
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
|
||||
|
||||
struct ApplyVectorLayoutPass
|
||||
: public impl::ApplyVectorLayoutPassBase<ApplyVectorLayoutPass> {
|
||||
ApplyVectorLayoutPass(int hardware_generation_, int lane_count_,
|
||||
int sublane_count_) {
|
||||
hardware_generation = hardware_generation_;
|
||||
sublane_count = sublane_count_;
|
||||
lane_count = lane_count_;
|
||||
}
|
||||
void runOnOperation() override {
|
||||
// Fail if hardware_generation has not been set from the default value.
|
||||
if (hardware_generation < 0) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
// No-op for now
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<OperationPass<func::FuncOp>> createApplyVectorLayoutPass(
|
||||
int hardware_generation, int lane_count, int sublane_count) {
|
||||
return std::make_unique<ApplyVectorLayoutPass>(hardware_generation,
|
||||
lane_count, sublane_count);
|
||||
}
|
||||
|
||||
} // namespace mlir::tpu
|
Loading…
x
Reference in New Issue
Block a user