[mosaic] Extracted serialization pass traversal logic into a reusable function

I will use it to implement Mosaic GPU serialization pass in a follow up.

PiperOrigin-RevId: 716156650
This commit is contained in:
Sergei Lebedev 2025-01-16 02:57:31 -08:00 committed by jax authors
parent 9a60e6fce4
commit 4221f109d1
4 changed files with 235 additions and 111 deletions

View File

@ -60,6 +60,7 @@ cc_library(
deps = [
":tpu_inc_gen",
"//jaxlib:pass_boilerplate",
"//jaxlib/mosaic:serde",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
@ -72,6 +73,7 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ControlFlowDialect",
"@llvm-project//mlir:DataLayoutInterfaces",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
@ -255,3 +257,16 @@ filegroup(
"dialect/tpu/transforms/extensions/infer_vector_layout_extensions.cc",
],
)
cc_library(
name = "serde",
srcs = ["serde.cc"],
hdrs = ["serde.h"],
# compatible with libtpu
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:DataLayoutInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

View File

@ -16,10 +16,6 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h"
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <string_view>
#include <vector>
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@ -36,34 +32,19 @@ limitations under the License.
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "mlir/include/mlir/Support/LogicalResult.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/serde.h"
namespace mlir::tpu {
namespace {
constexpr std::string_view kMangledDialect = "stable_mosaic.";
constexpr StringRef kMangledDialect = "stable_mosaic.";
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
// When this is bumped, we should file a TODO to update the forward-compatible
// version in tpu_custom_call.py in a month!
constexpr int kVersion = 3;
StringRef mangle(StringRef name, std::string* storage) {
storage->clear();
storage->reserve(kMangledDialect.size() + name.size());
storage->insert(storage->end(), kMangledDialect.begin(),
kMangledDialect.end());
storage->insert(storage->end(), name.begin(), name.end());
return *storage;
}
std::optional<StringRef> demangle(StringRef name) {
if (!name.starts_with(kMangledDialect)) {
return std::nullopt;
}
return name.drop_front(kMangledDialect.size());
}
using rule_type = std::function<LogicalResult(Operation*, int)>;
using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;
LogicalResult enqueue_dma_upgrade(Operation* op, int version) {
// Added AttrSizedOperandSegments and core_id in version 2.
@ -164,18 +145,17 @@ LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) {
return success();
}
const llvm::StringMap<rule_type>& upgrade_rules() {
static auto rules = new llvm::StringMap<rule_type>{
const llvm::StringMap<SerdeRuleType>& upgrade_rules() {
static auto rules = new llvm::StringMap<SerdeRuleType>{
{EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade},
{SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade},
{vector::MultiDimReductionOp::getOperationName(),
vector_multi_dim_reduce_upgrade}
};
vector_multi_dim_reduce_upgrade}};
return *rules;
}
const llvm::StringMap<rule_type>& downgrade_rules() {
static auto rules = new llvm::StringMap<rule_type>{
const llvm::StringMap<SerdeRuleType>& downgrade_rules() {
static auto rules = new llvm::StringMap<SerdeRuleType>{
{EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade},
{SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade},
{vector::MultiDimReductionOp::getOperationName(),
@ -191,92 +171,17 @@ void MosaicSerdePass::runOnOperation() {
module.emitError("serialize option must be specified");
return signalPassFailure();
}
int serialize_version = target_version.hasValue() ? target_version : kVersion;
if (serialize && serialize_version > kVersion) {
module.emitError("The highest supported version is ")
<< kVersion << " but requested serialization at version "
<< serialize_version;
return signalPassFailure();
}
if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
module.emitError() << "Cannot serialize within a context that does not "
"allow unregistered dialects.";
signalPassFailure();
return;
}
int version = kVersion;
int serialize_version = -1;
if (serialize) {
module->setAttr(kVersionAttrName,
IntegerAttr::get(IntegerType::get(module->getContext(), 64),
serialize_version));
} else {
IntegerAttr version_attr =
module->getAttrOfType<IntegerAttr>(kVersionAttrName);
if (!version_attr) {
module->emitError("Missing or invalid Mosaic version attribute");
signalPassFailure();
return;
}
if (version_attr.getInt() > kVersion) {
module->emitError("Unsupported Mosaic version: expected <= ")
<< kVersion << " but got " << version_attr.getInt();
signalPassFailure();
return;
}
version = version_attr.getInt();
module->removeAttr(kVersionAttrName);
serialize_version = target_version.hasValue() ? target_version : kVersion;
}
std::string name_storage;
auto result = module.walk([&](Operation* op) {
if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
return WalkResult::advance();
}
std::optional<OperationName> new_name;
if (serialize) {
auto new_name_str = mangle(op->getName().getStringRef(), &name_storage);
new_name = OperationName(new_name_str, op->getContext());
} else {
if (auto demangled = demangle(op->getName().getStringRef())) {
auto new_name_str = *demangled;
if (auto registered = RegisteredOperationName::lookup(
new_name_str, op->getContext())) {
new_name = *registered;
} else {
new_name = OperationName(new_name_str, op->getContext());
}
} else {
op->emitError("Operation not in a serialized form");
return WalkResult::interrupt();
}
// Upgrade the op to the current version, if needed.
if (const auto rule = upgrade_rules().find(new_name->getStringRef());
rule != upgrade_rules().end()) {
if (rule->second(op, version).failed()) {
return WalkResult::interrupt();
}
}
}
auto new_op = Operation::create(
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
// Downgrade the op to the target version, if needed.
if (serialize && kVersion != serialize_version) {
if (const auto rule =
downgrade_rules().find(op->getName().getStringRef());
rule != downgrade_rules().end()) {
if (rule->second(new_op, serialize_version).failed()) {
return WalkResult::interrupt();
}
}
}
op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
op->replaceAllUsesWith(new_op->getResults());
op->erase();
return WalkResult::advance();
});
if (result.wasInterrupted()) {
if (failed(jaxlib::mosaic::RunSerde(
module, upgrade_rules(), downgrade_rules(), serialize,
{.dialect_prefix = kMangledDialect,
.highest_version = kVersion,
.version_attr_name = kVersionAttrName,
.serialize_version = serialize_version}))) {
signalPassFailure();
return;
}
}

148
jaxlib/mosaic/serde.cc Normal file
View File

@ -0,0 +1,148 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/serde.h"
#include <optional>
#include <string>
#include "llvm/include/llvm/ADT/StringMap.h"
#include "llvm/include/llvm/ADT/StringRef.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/IR/Operation.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "mlir/include/mlir/IR/Visitors.h"
#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/include/mlir/Support/LLVM.h"
namespace jaxlib::mosaic {
namespace {
llvm::StringRef mangle(llvm::StringRef name, llvm::StringRef prefix,
std::string* storage) {
storage->clear();
storage->reserve(prefix.size() + name.size());
storage->insert(storage->end(), prefix.begin(), prefix.end());
storage->insert(storage->end(), name.begin(), name.end());
return *storage;
}
std::optional<llvm::StringRef> demangle(llvm::StringRef name,
llvm::StringRef prefix) {
if (!name.starts_with(prefix)) {
return std::nullopt;
}
return name.drop_front(prefix.size());
}
} // namespace
mlir::LogicalResult RunSerde(
mlir::ModuleOp module, const llvm::StringMap<SerdeRuleType>& upgrade_rules,
const llvm::StringMap<SerdeRuleType>& downgrade_rules, bool serialize,
SerdeOptions options) {
int version = options.highest_version;
int serialize_version = options.serialize_version;
if (!serialize && serialize_version != -1) {
module.emitError("Cannot deserialize to a specific version");
return mlir::failure();
}
if (serialize && serialize_version > version) {
module.emitError("The highest supported version is ")
<< version << " but requested serialization at version "
<< serialize_version;
return mlir::failure();
}
if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
module.emitError() << "Cannot serialize within a context that does not "
"allow unregistered dialects";
return mlir::failure();
}
if (serialize) {
module->setAttr(
options.version_attr_name,
mlir::IntegerAttr::get(mlir::IntegerType::get(module->getContext(), 64),
serialize_version));
} else {
mlir::IntegerAttr version_attr =
module->getAttrOfType<mlir::IntegerAttr>(options.version_attr_name);
if (!version_attr) {
module->emitError("Missing or invalid version attribute");
return mlir::failure();
}
if (version_attr.getInt() > version) {
module->emitError("Unsupported version: expected <= ")
<< version << " but got " << version_attr.getInt();
return mlir::failure();
}
version = version_attr.getInt();
module->removeAttr(options.version_attr_name);
}
std::string storage;
auto result = module.walk([&](mlir::Operation* op) {
if (mlir::isa<mlir::ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
return mlir::WalkResult::advance();
}
std::optional<mlir::OperationName> new_name;
if (serialize) {
auto new_name_str = mangle(op->getName().getStringRef(),
options.dialect_prefix, &storage);
new_name = mlir::OperationName(new_name_str, op->getContext());
} else {
if (auto demangled =
demangle(op->getName().getStringRef(), options.dialect_prefix)) {
auto new_name_str = *demangled;
if (auto registered = mlir::RegisteredOperationName::lookup(
new_name_str, op->getContext())) {
new_name = *registered;
} else {
new_name = mlir::OperationName(new_name_str, op->getContext());
}
} else {
op->emitError("Operation not in a serialized form");
return mlir::WalkResult::interrupt();
}
// Upgrade the op to the current version, if needed.
if (const auto rule = upgrade_rules.find(new_name->getStringRef());
rule != upgrade_rules.end()) {
if (rule->second(op, version).failed()) {
return mlir::WalkResult::interrupt();
}
}
}
auto new_op = mlir::Operation::create(
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
// Downgrade the op to the target version, if needed.
if (serialize && version != serialize_version) {
if (const auto rule = downgrade_rules.find(op->getName().getStringRef());
rule != downgrade_rules.end()) {
if (rule->second(new_op, serialize_version).failed()) {
return mlir::WalkResult::interrupt();
}
}
}
op->getBlock()->getOperations().insertAfter(mlir::Block::iterator(op),
new_op);
op->replaceAllUsesWith(new_op->getResults());
op->erase();
return mlir::WalkResult::advance();
});
return result.wasInterrupted() ? mlir::failure() : mlir::success();
}
} // namespace jaxlib::mosaic

56
jaxlib/mosaic/serde.h Normal file
View File

@ -0,0 +1,56 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_SERDE_H_
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_SERDE_H_
#include <functional>
#include "llvm/include/llvm/ADT/StringMap.h"
#include "llvm/include/llvm/ADT/StringRef.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/Support/LLVM.h"
namespace jaxlib::mosaic {
struct SerdeOptions {
llvm::StringRef dialect_prefix; // mangled dialect prefix
int highest_version; // the highest supported version
llvm::StringRef version_attr_name;
int serialize_version; // target version for serialization, must be -1 when
// deserializing
};
// A rule for upgrading or downgrading an operation.
//
// The first argument is the operation to upgrade/downgrade.
// The second argument is the target version.
//
// The function should return success if the upgrade/downgrade was successful,
// or an error otherwise.
using SerdeRuleType =
std::function<::mlir::LogicalResult(::mlir::Operation *, int)>;
// Run serialization or deserialization on the given module.
::mlir::LogicalResult RunSerde(
::mlir::ModuleOp module,
const llvm::StringMap<SerdeRuleType> &upgrade_rules,
const llvm::StringMap<SerdeRuleType> &downgrade_rules, bool serialize,
SerdeOptions options);
} // namespace jaxlib::mosaic
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_SERDE_H_