diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 875e182fa..3b8750f98 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -170,6 +170,17 @@ cc_library( ], ) +cc_library( + name = "pass_boilerplate", + hdrs = ["pass_boilerplate.h"], + # compatible with libtpu + deps = [ + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "handle_pool", hdrs = ["handle_pool.h"], diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index a50aef1ca..3061cd399 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -316,6 +316,7 @@ MlirContext getDefaultContext() { PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { mlirRegisterTPUPasses(); // Register all passes on load. + mlirTpuRegisterMosaicSerdePass(); py::class_(m, "ApplyVectorLayoutCtx", py::module_local()) diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index da7498ed4..238bf42d9 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -56,6 +56,7 @@ cc_library( # compatible with libtpu deps = [ ":tpu_inc_gen", + "//jaxlib:pass_boilerplate", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index 3cc9b3697..ea05d67ca 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -44,6 +44,7 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" #include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h" +#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" #include "xla/array.h" // TODO(tlongeri): null pointer checks? @@ -408,6 +409,10 @@ MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val, } } +MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() { + mlir::tpu::registerMosaicSerdePass(); +} + #include "mlir/CAPI/Pass.h" // IWYU pragma: keep #include "mlir/CAPI/Support.h" // IWYU pragma: keep diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h index 5b2a7009e..18b108c84 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h @@ -19,6 +19,7 @@ limitations under the License. #ifndef JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_ #define JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_ +#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h" #ifndef __cplusplus #include #endif @@ -234,6 +235,10 @@ MLIR_CAPI_EXPORTED MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val, MlirTpuVectorLayout src, MlirTpuVectorLayout dst, MlirTpuApplyVectorLayoutContext ctx); + + +MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass(); + #ifdef __cplusplus } #endif diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 70365a32a..ed5c2f2da 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -786,13 +786,6 @@ def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::Fun let constructor = "::mlir::tpu::createDebugAssertInsertionPass()"; } -def MosaicSerdePass : Pass<"mosaic-serde", "::mlir::ModuleOp"> { - let options = [ - Option<"serialize", "serialize", "bool", "", "">, - Option<"target_version", "target-version", "int", "", ""> // Only used when serialize=true. - ]; -} - def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> { let dependentDialects = [ "::mlir::func::FuncDialect", diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index a8569acc6..c01e56589 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/include/mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/layout.h" #include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc" +#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" #include "xla/layout.h" namespace mlir::tpu { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 6717e3a3e..9c77ce466 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -13,21 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// We need to keep some extra headers for the code in tpu_passes.h.inc. +#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" -#include // IWYU pragma: keep +#include +#include #include #include #include +#include +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" -#include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" +#include "llvm/include/llvm/ADT/StringMap.h" +#include "mlir/include/mlir/IR/Attributes.h" #include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" @@ -36,9 +39,6 @@ limitations under the License. namespace mlir::tpu { -#define GEN_PASS_DEF_MOSAICSERDEPASS -#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" - namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; @@ -183,107 +183,101 @@ const llvm::StringMap& downgrade_rules() { return *rules; } -struct MosaicSerdePass : public impl::MosaicSerdePassBase { - using Base::Base; - - void runOnOperation() override { - ModuleOp module = getOperation(); - if (!serialize.hasValue()) { - module.emitError("serialize option must be specified"); - return signalPassFailure(); - } - int serialize_version = - target_version.hasValue() ? target_version : kVersion; - if (serialize && serialize_version > kVersion) { - module.emitError("The highest supported version is ") - << kVersion << " but requested serialization at version " - << serialize_version; - return signalPassFailure(); - } - if (serialize && !module->getContext()->allowsUnregisteredDialects()) { - module.emitError() << "Cannot serialize within a context that does not " - "allow unregistered dialects."; - signalPassFailure(); - return; - } - int version = kVersion; - if (serialize) { - module->setAttr( - kVersionAttrName, - IntegerAttr::get(IntegerType::get(module->getContext(), 64), - serialize_version)); - } else { - IntegerAttr version_attr = - module->getAttrOfType(kVersionAttrName); - if (!version_attr) { - module->emitError("Missing or invalid Mosaic version attribute"); - signalPassFailure(); - return; - } - if (version_attr.getInt() > kVersion) { - module->emitError("Unsupported Mosaic version: expected <= ") - << kVersion << " but got " << version_attr.getInt(); - signalPassFailure(); - return; - } - version = version_attr.getInt(); - module->removeAttr(kVersionAttrName); - } - std::string name_storage; - auto result = module.walk([&](Operation* op) { - if (isa(op)) { // Don't mangle the ModuleOp itself. - return WalkResult::advance(); - } - std::optional new_name; - if (serialize) { - auto new_name_str = mangle(op->getName().getStringRef(), &name_storage); - new_name = OperationName(new_name_str, op->getContext()); - } else { - if (auto demangled = demangle(op->getName().getStringRef())) { - auto new_name_str = *demangled; - if (auto registered = RegisteredOperationName::lookup( - new_name_str, op->getContext())) { - new_name = *registered; - } else { - new_name = OperationName(new_name_str, op->getContext()); - } - } else { - op->emitError("Operation not in a serialized form"); - return WalkResult::interrupt(); - } - // Upgrade the op to the current version, if needed. - if (const auto rule = upgrade_rules().find(new_name->getStringRef()); - rule != upgrade_rules().end()) { - if (rule->second(op, version).failed()) { - return WalkResult::interrupt(); - } - } - } - auto new_op = Operation::create( - op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(), - op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions()); - // Downgrade the op to the target version, if needed. - if (serialize && kVersion != serialize_version) { - if (const auto rule = - downgrade_rules().find(op->getName().getStringRef()); - rule != downgrade_rules().end()) { - if (rule->second(new_op, serialize_version).failed()) { - return WalkResult::interrupt(); - } - } - } - op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op); - op->replaceAllUsesWith(new_op->getResults()); - op->erase(); - return WalkResult::advance(); - }); - if (result.wasInterrupted()) { - signalPassFailure(); - return; - } - } -}; - } // namespace -} // namespace mlir::tpu \ No newline at end of file +void MosaicSerdePass::runOnOperation() { + ModuleOp module = getOperation(); + if (!serialize.hasValue()) { + module.emitError("serialize option must be specified"); + return signalPassFailure(); + } + int serialize_version = target_version.hasValue() ? target_version : kVersion; + if (serialize && serialize_version > kVersion) { + module.emitError("The highest supported version is ") + << kVersion << " but requested serialization at version " + << serialize_version; + return signalPassFailure(); + } + if (serialize && !module->getContext()->allowsUnregisteredDialects()) { + module.emitError() << "Cannot serialize within a context that does not " + "allow unregistered dialects."; + signalPassFailure(); + return; + } + int version = kVersion; + if (serialize) { + module->setAttr(kVersionAttrName, + IntegerAttr::get(IntegerType::get(module->getContext(), 64), + serialize_version)); + } else { + IntegerAttr version_attr = + module->getAttrOfType(kVersionAttrName); + if (!version_attr) { + module->emitError("Missing or invalid Mosaic version attribute"); + signalPassFailure(); + return; + } + if (version_attr.getInt() > kVersion) { + module->emitError("Unsupported Mosaic version: expected <= ") + << kVersion << " but got " << version_attr.getInt(); + signalPassFailure(); + return; + } + version = version_attr.getInt(); + module->removeAttr(kVersionAttrName); + } + std::string name_storage; + auto result = module.walk([&](Operation* op) { + if (isa(op)) { // Don't mangle the ModuleOp itself. + return WalkResult::advance(); + } + std::optional new_name; + if (serialize) { + auto new_name_str = mangle(op->getName().getStringRef(), &name_storage); + new_name = OperationName(new_name_str, op->getContext()); + } else { + if (auto demangled = demangle(op->getName().getStringRef())) { + auto new_name_str = *demangled; + if (auto registered = RegisteredOperationName::lookup( + new_name_str, op->getContext())) { + new_name = *registered; + } else { + new_name = OperationName(new_name_str, op->getContext()); + } + } else { + op->emitError("Operation not in a serialized form"); + return WalkResult::interrupt(); + } + // Upgrade the op to the current version, if needed. + if (const auto rule = upgrade_rules().find(new_name->getStringRef()); + rule != upgrade_rules().end()) { + if (rule->second(op, version).failed()) { + return WalkResult::interrupt(); + } + } + } + auto new_op = Operation::create( + op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(), + op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions()); + // Downgrade the op to the target version, if needed. + if (serialize && kVersion != serialize_version) { + if (const auto rule = + downgrade_rules().find(op->getName().getStringRef()); + rule != downgrade_rules().end()) { + if (rule->second(new_op, serialize_version).failed()) { + return WalkResult::interrupt(); + } + } + } + op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op); + op->replaceAllUsesWith(new_op->getResults()); + op->erase(); + return WalkResult::advance(); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + return; + } +} + +} // namespace mlir::tpu diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.h b/jaxlib/mosaic/dialect/tpu/transforms/serde.h new file mode 100644 index 000000000..8685918d3 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.h @@ -0,0 +1,70 @@ +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ + +#include +#include + +#include "llvm/include/llvm/ADT/StringRef.h" +#include "llvm/include/llvm/Support/CommandLine.h" +#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/include/mlir/Pass/Pass.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "jaxlib/pass_boilerplate.h" + +namespace mlir::tpu { + +struct MosaicSerdePassOptions { + bool serialize; + int target_version; +}; + +struct MosaicSerdePass : public jaxlib::mlir::Pass { + using jaxlib::mlir::Pass::Pass; + + static constexpr llvm::StringLiteral kArgumentName = "mosaic-serde"; + static constexpr llvm::StringLiteral kPassName = "MosaicSerdePass"; + + MosaicSerdePass() = default; + + explicit MosaicSerdePass(MosaicSerdePassOptions options) { + serialize = options.serialize; + target_version = options.target_version; + } + + MosaicSerdePass(const MosaicSerdePass &other) { + serialize = other.serialize; + target_version = other.target_version; + } + + MosaicSerdePass &operator=(const MosaicSerdePass &other) { + serialize = other.serialize; + target_version = other.target_version; + return *this; + } + + void runOnOperation(); + + protected: + ::mlir::Pass::Option serialize{*this, "serialize", llvm::cl::desc("")}; + ::mlir::Pass::Option target_version{*this, "target-version", + llvm::cl::desc("")}; +}; + +inline std::unique_ptr<::mlir::Pass> createMosaicSerdePass() { + return std::make_unique(); +} + +inline std::unique_ptr<::mlir::Pass> createMosaicSerdePass( + MosaicSerdePassOptions options) { + return std::make_unique(std::move(options)); +} + +inline void registerMosaicSerdePass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return createMosaicSerdePass(); + }); +} + +} // namespace mlir::tpu + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_ diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index cb52488e7..2139a2666 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -47,10 +47,10 @@ cc_library( ], hdrs = [ "launch_lowering.h", - "pass_boilerplate.h", "passes.h", ], deps = [ + "//jaxlib:pass_boilerplate", "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:FuncDialect", diff --git a/jaxlib/mosaic/gpu/passes.cc b/jaxlib/mosaic/gpu/passes.cc index 50404f11a..cee34ddae 100644 --- a/jaxlib/mosaic/gpu/passes.cc +++ b/jaxlib/mosaic/gpu/passes.cc @@ -29,7 +29,7 @@ limitations under the License. #include "mlir/include/mlir/Pass/PassRegistry.h" #include "mlir/include/mlir/Support/LLVM.h" #include "mlir/include/mlir/Transforms/DialectConversion.h" -#include "jaxlib/mosaic/gpu/pass_boilerplate.h" +#include "jaxlib/pass_boilerplate.h" namespace mosaic { namespace gpu { @@ -37,9 +37,9 @@ namespace gpu { namespace { class ConvertGpuToLLVMPass - : public mosaic::gpu::Pass { + : public jaxlib::mlir::Pass { public: - using mosaic::gpu::Pass::Pass; + using jaxlib::mlir::Pass::Pass; static constexpr llvm::StringLiteral kArgumentName = "mosaic-convert-gpu-to-llvm"; static constexpr llvm::StringLiteral kPassName = "ConvertGpuToLLVMPass"; @@ -71,9 +71,9 @@ class ConvertGpuToLLVMPass // We only use arrays to pass in TMA descriptors, which is why we also // require 64-byte alignment. class ByvalInsertionPass - : public mosaic::gpu::Pass { + : public jaxlib::mlir::Pass { public: - using mosaic::gpu::Pass::Pass; + using jaxlib::mlir::Pass::Pass; static constexpr llvm::StringLiteral kArgumentName = "mosaic-byval-insertion"; static constexpr llvm::StringLiteral kPassName = "ByvalInsertionPass"; diff --git a/jaxlib/mosaic/gpu/pass_boilerplate.h b/jaxlib/pass_boilerplate.h similarity index 90% rename from jaxlib/mosaic/gpu/pass_boilerplate.h rename to jaxlib/pass_boilerplate.h index b0241fca9..b9754a873 100644 --- a/jaxlib/mosaic/gpu/pass_boilerplate.h +++ b/jaxlib/pass_boilerplate.h @@ -13,15 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ -#define JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ +#ifndef JAXLIB_PASS_BOILERPLATE_H_ +#define JAXLIB_PASS_BOILERPLATE_H_ + +#include #include "mlir/include/mlir/IR/DialectRegistry.h" #include "mlir/include/mlir/Pass/Pass.h" #include "mlir/include/mlir/Support/LLVM.h" #include "mlir/include/mlir/Support/TypeID.h" -namespace mosaic { -namespace gpu { + +namespace jaxlib { +namespace mlir { template class Pass : public ::mlir::OperationPass { @@ -58,7 +61,7 @@ class Pass : public ::mlir::OperationPass { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(This) }; -} // namespace gpu -} // namespace mosaic +} // namespace mlir +} // namespace jaxlib -#endif // JAXLIB_MOSAIC_GPU_PASS_BOILERPLATE_H_ +#endif // JAXLIB_PASS_BOILERPLATE_H_