[Mosaic] Add a pass for instantiating memory spaces

PiperOrigin-RevId: 564723473
This commit is contained in:
Adam Paszke 2023-09-12 08:04:49 -07:00 committed by jax authors
parent 7dddb507e9
commit dbb0e8f214
3 changed files with 106 additions and 1 deletions

View File

@ -113,8 +113,11 @@ def TPU_TiledLayoutAttr
def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [
I32EnumAttrCase<"kAny", 4294967295, "any">,
// TODO(apaszke): Rename to kXYZ in C++
I32EnumAttrCase<"vmem", 0, "vmem">,
I32EnumAttrCase<"smem", 1, "smem">
I32EnumAttrCase<"smem", 1, "smem">,
I32EnumAttrCase<"kHbm", 2, "hbm">,
I32EnumAttrCase<"kCmem", 3, "cmem">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tpu";

View File

@ -23,6 +23,9 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LogicalResult.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc"
#include "xla/layout.h"
@ -57,6 +60,10 @@ createLogicalToPhysicalDeviceIdPass(int64_t total_devices);
std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass();
// Changes the memory space of the value and propagates it through the program.
LogicalResult specializeMemorySpace(TypedValue<MemRefType> value,
MemorySpace memory_space);
// In Mosaic, we often strip tiled layouts from memrefs, for compatibility with
// vector ops. This functions inverts the layout erasure applied to the value.
MemRefType getMemRefType(Value value);

View File

@ -0,0 +1,95 @@
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/log/check.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "mlir/include/mlir/Support/LogicalResult.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir {
namespace tpu {
namespace {
MemRefType updateMemorySpace(MemRefType ty, Attribute memory_space) {
return MemRefType::get(ty.getShape(), ty.getElementType(), ty.getLayout(),
memory_space);
}
MemRefType updateMemorySpace(MemRefType ty, MemorySpace memory_space) {
return updateMemorySpace(ty,
MemorySpaceAttr::get(ty.getContext(), memory_space));
}
} // namespace
LogicalResult specializeMemorySpace(TypedValue<MemRefType> value,
MemorySpace memory_space) {
MemorySpaceAttr attr =
dyn_cast_if_present<MemorySpaceAttr>(value.getType().getMemorySpace());
if (!attr) {
return failure();
}
MemorySpace current_memory_space = attr.getValue();
if (current_memory_space == memory_space) {
return success(); // Nothing to do here.
} else if (current_memory_space != MemorySpace::kAny) {
return failure(); // Memory space mismatch!
}
value.setType(updateMemorySpace(value.getType(), memory_space));
std::vector<Operation*> to_update(value.getUsers().begin(),
value.getUsers().end());
auto updateResultFrom = [&](Operation* op, MemRefType ty) {
Attribute source_memory_space = ty.getMemorySpace();
CHECK_EQ(op->getNumResults(), 1);
Value result = op->getResult(0);
MemRefType result_type = cast<MemRefType>(result.getType());
if (result_type.getMemorySpace() != source_memory_space) {
result.setType(updateMemorySpace(result_type, source_memory_space));
to_update.insert(to_update.end(), result.getUsers().begin(),
result.getUsers().end());
}
};
while (!to_update.empty()) {
Operation* some_op = to_update.back();
to_update.pop_back();
// Here we only have to handle the operations allowed on refs with
// unspecified memory space.
if (auto op = dyn_cast<tpu::MemRefSliceOp>(some_op)) {
updateResultFrom(op, op.getMemRef().getType());
continue;
}
if (auto op = dyn_cast<tpu::EraseLayoutOp>(some_op)) {
updateResultFrom(op, op.getOperand().getType());
continue;
}
if (auto op = dyn_cast<tpu::EnqueueDMAOp>(some_op)) {
continue; // Nothing to do.
}
if (auto op = dyn_cast<tpu::WaitDMAOp>(some_op)) {
continue; // Nothing to do.
}
some_op->emitOpError(
"Failed to propagate memory space update through this operation");
return failure();
}
return success();
}
} // namespace tpu
} // namespace mlir