diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 1aef1ebdd..16fa666b6 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -60,6 +60,7 @@ cc_library( deps = [ ":tpu_inc_gen", "//jaxlib:pass_boilerplate", + "//jaxlib/mosaic:serde", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", @@ -72,6 +73,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", + "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", @@ -255,3 +257,16 @@ filegroup( "dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc", ], ) + +cc_library( + name = "serde", + srcs = ["serde.cc"], + hdrs = ["serde.h"], + # compatible with libtpu + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 9c77ce466..0981c263d 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -16,10 +16,6 @@ limitations under the License. #include "jaxlib/mosaic/dialect/tpu/transforms/serde.h" #include -#include -#include -#include -#include #include #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -36,34 +32,19 @@ limitations under the License. 
#include "mlir/include/mlir/IR/OperationSupport.h" #include "mlir/include/mlir/Support/LogicalResult.h" #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +#include "jaxlib/mosaic/serde.h" namespace mlir::tpu { namespace { -constexpr std::string_view kMangledDialect = "stable_mosaic."; +constexpr StringRef kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; // When this is bumped, we should file a TODO to update the forward-compatible // version in tpu_custom_call.py in a month! constexpr int kVersion = 3; -StringRef mangle(StringRef name, std::string* storage) { - storage->clear(); - storage->reserve(kMangledDialect.size() + name.size()); - storage->insert(storage->end(), kMangledDialect.begin(), - kMangledDialect.end()); - storage->insert(storage->end(), name.begin(), name.end()); - return *storage; -} - -std::optional demangle(StringRef name) { - if (!name.starts_with(kMangledDialect)) { - return std::nullopt; - } - return name.drop_front(kMangledDialect.size()); -} - -using rule_type = std::function; +using SerdeRuleType = jaxlib::mosaic::SerdeRuleType; LogicalResult enqueue_dma_upgrade(Operation* op, int version) { // Added AttrSizedOperandSegments and core_id in version 2. 
@@ -164,18 +145,17 @@ LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) {
   return success();
 }
 
-const llvm::StringMap<rule_type>& upgrade_rules() {
-  static auto rules = new llvm::StringMap<rule_type>{
+const llvm::StringMap<SerdeRuleType>& upgrade_rules() {
+  static auto rules = new llvm::StringMap<SerdeRuleType>{
       {EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade},
       {SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade},
       {vector::MultiDimReductionOp::getOperationName(),
-       vector_multi_dim_reduce_upgrade}
-  };
+       vector_multi_dim_reduce_upgrade}};
   return *rules;
 }
 
-const llvm::StringMap<rule_type>& downgrade_rules() {
-  static auto rules = new llvm::StringMap<rule_type>{
+const llvm::StringMap<SerdeRuleType>& downgrade_rules() {
+  static auto rules = new llvm::StringMap<SerdeRuleType>{
       {EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade},
       {SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade},
       {vector::MultiDimReductionOp::getOperationName(),
@@ -191,92 +171,17 @@ void MosaicSerdePass::runOnOperation() {
     module.emitError("serialize option must be specified");
     return signalPassFailure();
   }
-  int serialize_version = target_version.hasValue() ? target_version : kVersion;
-  if (serialize && serialize_version > kVersion) {
-    module.emitError("The highest supported version is ")
-        << kVersion << " but requested serialization at version "
-        << serialize_version;
-    return signalPassFailure();
-  }
-  if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
-    module.emitError() << "Cannot serialize within a context that does not "
-                          "allow unregistered dialects.";
-    signalPassFailure();
-    return;
-  }
-  int version = kVersion;
+  int serialize_version = -1;
   if (serialize) {
-    module->setAttr(kVersionAttrName,
-                    IntegerAttr::get(IntegerType::get(module->getContext(), 64),
-                                     serialize_version));
-  } else {
-    IntegerAttr version_attr =
-        module->getAttrOfType<IntegerAttr>(kVersionAttrName);
-    if (!version_attr) {
-      module->emitError("Missing or invalid Mosaic version attribute");
-      signalPassFailure();
-      return;
-    }
-    if (version_attr.getInt() > kVersion) {
-      module->emitError("Unsupported Mosaic version: expected <= ")
-          << kVersion << " but got " << version_attr.getInt();
-      signalPassFailure();
-      return;
-    }
-    version = version_attr.getInt();
-    module->removeAttr(kVersionAttrName);
+    serialize_version = target_version.hasValue() ? target_version : kVersion;
   }
-  std::string name_storage;
-  auto result = module.walk([&](Operation* op) {
-    if (isa<ModuleOp>(op)) {  // Don't mangle the ModuleOp itself.
-      return WalkResult::advance();
-    }
-    std::optional<OperationName> new_name;
-    if (serialize) {
-      auto new_name_str = mangle(op->getName().getStringRef(), &name_storage);
-      new_name = OperationName(new_name_str, op->getContext());
-    } else {
-      if (auto demangled = demangle(op->getName().getStringRef())) {
-        auto new_name_str = *demangled;
-        if (auto registered = RegisteredOperationName::lookup(
-                new_name_str, op->getContext())) {
-          new_name = *registered;
-        } else {
-          new_name = OperationName(new_name_str, op->getContext());
-        }
-      } else {
-        op->emitError("Operation not in a serialized form");
-        return WalkResult::interrupt();
-      }
-      // Upgrade the op to the current version, if needed.
-      if (const auto rule = upgrade_rules().find(new_name->getStringRef());
-          rule != upgrade_rules().end()) {
-        if (rule->second(op, version).failed()) {
-          return WalkResult::interrupt();
-        }
-      }
-    }
-    auto new_op = Operation::create(
-        op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
-        op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
-    // Downgrade the op to the target version, if needed.
-    if (serialize && kVersion != serialize_version) {
-      if (const auto rule =
-              downgrade_rules().find(op->getName().getStringRef());
-          rule != downgrade_rules().end()) {
-        if (rule->second(new_op, serialize_version).failed()) {
-          return WalkResult::interrupt();
-        }
-      }
-    }
-    op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
-    op->replaceAllUsesWith(new_op->getResults());
-    op->erase();
-    return WalkResult::advance();
-  });
-  if (result.wasInterrupted()) {
+  if (failed(jaxlib::mosaic::RunSerde(
+          module, upgrade_rules(), downgrade_rules(), serialize,
+          {.dialect_prefix = kMangledDialect,
+           .highest_version = kVersion,
+           .version_attr_name = kVersionAttrName,
+           .serialize_version = serialize_version}))) {
     signalPassFailure();
-    return;
   }
 }
 
diff --git a/jaxlib/mosaic/serde.cc b/jaxlib/mosaic/serde.cc
new file mode 100644
index 000000000..88bca44bf
--- /dev/null
+++ b/jaxlib/mosaic/serde.cc
@@ -0,0 +1,161 @@
+/* Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/mosaic/serde.h"
+
+#include <optional>
+#include <string>
+
+#include "llvm/include/llvm/ADT/StringMap.h"
+#include "llvm/include/llvm/ADT/StringRef.h"
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"
+#include "mlir/include/mlir/IR/BuiltinOps.h"
+#include "mlir/include/mlir/IR/Operation.h"
+#include "mlir/include/mlir/IR/OperationSupport.h"
+#include "mlir/include/mlir/IR/Visitors.h"
+#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/include/mlir/Support/LLVM.h"
+
+namespace jaxlib::mosaic {
+
+namespace {
+
+// Writes `prefix` followed by `name` into `*storage` (replacing its previous
+// contents) and returns a view of the result. The returned StringRef aliases
+// `*storage`, so it is only valid until the storage is next modified.
+llvm::StringRef mangle(llvm::StringRef name, llvm::StringRef prefix,
+                       std::string* storage) {
+  storage->clear();
+  storage->reserve(prefix.size() + name.size());
+  storage->insert(storage->end(), prefix.begin(), prefix.end());
+  storage->insert(storage->end(), name.begin(), name.end());
+  return *storage;
+}
+
+// Returns `name` with the leading `prefix` removed, or std::nullopt if `name`
+// does not start with `prefix`.
+std::optional<llvm::StringRef> demangle(llvm::StringRef name,
+                                        llvm::StringRef prefix) {
+  if (!name.starts_with(prefix)) {
+    return std::nullopt;
+  }
+  return name.drop_front(prefix.size());
+}
+
+}  // namespace
+
+// Serializes (`serialize` == true) or deserializes `module` in place. Every
+// operation except the module itself is re-created under a mangled name when
+// serializing, or under its demangled name when deserializing. Matching
+// `upgrade_rules` run while deserializing; matching `downgrade_rules` run when
+// serializing to a version other than `options.highest_version`.
+mlir::LogicalResult RunSerde(
+    mlir::ModuleOp module, const llvm::StringMap<SerdeRuleType>& upgrade_rules,
+    const llvm::StringMap<SerdeRuleType>& downgrade_rules, bool serialize,
+    SerdeOptions options) {
+  int version = options.highest_version;
+  int serialize_version = options.serialize_version;
+  if (!serialize && serialize_version != -1) {
+    module.emitError("Cannot deserialize to a specific version");
+    return mlir::failure();
+  }
+  if (serialize && serialize_version > version) {
+    module.emitError("The highest supported version is ")
+        << version << " but requested serialization at version "
+        << serialize_version;
+    return mlir::failure();
+  }
+  if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
+    module.emitError() << "Cannot serialize within a context that does not "
+                          "allow unregistered dialects";
+    return mlir::failure();
+  }
+  if (serialize) {
+    module->setAttr(
+        options.version_attr_name,
+        mlir::IntegerAttr::get(mlir::IntegerType::get(module->getContext(), 64),
+                               serialize_version));
+  } else {
+    mlir::IntegerAttr version_attr =
+        module->getAttrOfType<mlir::IntegerAttr>(options.version_attr_name);
+    if (!version_attr) {
+      module->emitError("Missing or invalid version attribute");
+      return mlir::failure();
+    }
+    if (version_attr.getInt() > version) {
+      module->emitError("Unsupported version: expected <= ")
+          << version << " but got " << version_attr.getInt();
+      return mlir::failure();
+    }
+    version = version_attr.getInt();
+    module->removeAttr(options.version_attr_name);
+  }
+  std::string storage;
+  auto result = module.walk([&](mlir::Operation* op) {
+    if (mlir::isa<mlir::ModuleOp>(op)) {  // Don't mangle the ModuleOp itself.
+      return mlir::WalkResult::advance();
+    }
+    std::optional<mlir::OperationName> new_name;
+    if (serialize) {
+      auto new_name_str = mangle(op->getName().getStringRef(),
+                                 options.dialect_prefix, &storage);
+      new_name = mlir::OperationName(new_name_str, op->getContext());
+    } else {
+      if (auto demangled =
+              demangle(op->getName().getStringRef(), options.dialect_prefix)) {
+        auto new_name_str = *demangled;
+        if (auto registered = mlir::RegisteredOperationName::lookup(
+                new_name_str, op->getContext())) {
+          new_name = *registered;
+        } else {
+          new_name = mlir::OperationName(new_name_str, op->getContext());
+        }
+      } else {
+        op->emitError("Operation not in a serialized form");
+        return mlir::WalkResult::interrupt();
+      }
+      // Upgrade the op to the current version, if needed.
+      if (const auto rule = upgrade_rules.find(new_name->getStringRef());
+          rule != upgrade_rules.end()) {
+        if (rule->second(op, version).failed()) {
+          return mlir::WalkResult::interrupt();
+        }
+      }
+    }
+    auto new_op = mlir::Operation::create(
+        op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
+        op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
+    // Downgrade the op to the target version, if needed.
+    if (serialize && version != serialize_version) {
+      if (const auto rule = downgrade_rules.find(op->getName().getStringRef());
+          rule != downgrade_rules.end()) {
+        if (rule->second(new_op, serialize_version).failed()) {
+          // `new_op` has not been inserted into any block yet, so it must be
+          // freed here; otherwise it would leak when the walk is interrupted.
+          new_op->destroy();
+          return mlir::WalkResult::interrupt();
+        }
+      }
+    }
+    op->getBlock()->getOperations().insertAfter(mlir::Block::iterator(op),
+                                                new_op);
+    op->replaceAllUsesWith(new_op->getResults());
+    op->erase();
+    return mlir::WalkResult::advance();
+  });
+  return result.wasInterrupted() ? mlir::failure() : mlir::success();
+}
+
+}  // namespace jaxlib::mosaic
diff --git a/jaxlib/mosaic/serde.h b/jaxlib/mosaic/serde.h
new file mode 100644
index 000000000..762d9e5da
--- /dev/null
+++ b/jaxlib/mosaic/serde.h
@@ -0,0 +1,60 @@
+/* Copyright 2025 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_SERDE_H_
+#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_SERDE_H_
+
+#include <functional>
+
+#include "llvm/include/llvm/ADT/StringMap.h"
+#include "llvm/include/llvm/ADT/StringRef.h"
+#include "mlir/include/mlir/IR/BuiltinAttributes.h"
+#include "mlir/include/mlir/IR/BuiltinOps.h"
+#include "mlir/include/mlir/Support/LLVM.h"
+
+namespace jaxlib::mosaic {
+
+struct SerdeOptions {
+  llvm::StringRef dialect_prefix;  // mangled dialect prefix
+  int highest_version;             // the highest supported version
+  llvm::StringRef version_attr_name;
+  int serialize_version;  // target version for serialization, must be -1 when
+                          // deserializing
+};
+
+// A rule for upgrading or downgrading an operation.
+//
+// The first argument is the operation to upgrade/downgrade.
+// The second argument is the target version.
+//
+// The function should return success if the upgrade/downgrade was successful,
+// or an error otherwise.
+using SerdeRuleType =
+    std::function<::mlir::LogicalResult(::mlir::Operation *, int)>;
+
+// Run serialization or deserialization on the given module.
+//
+// Both rule maps are looked up by unmangled operation name. Returns failure if
+// the options or the module's version attribute are rejected, if an operation
+// is not in serialized form when deserializing, or if a rule fails.
+::mlir::LogicalResult RunSerde(
+    ::mlir::ModuleOp module,
+    const llvm::StringMap<SerdeRuleType> &upgrade_rules,
+    const llvm::StringMap<SerdeRuleType> &downgrade_rules, bool serialize,
+    SerdeOptions options);
+
+}  // namespace jaxlib::mosaic
+
+#endif  // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_SERDE_H_