Jevin Jiang 2faf540203 [Mosaic TPU] Add relayout-insertion pass and support bitwidth change for i1 vector relayout
We can use relayout-insertion pass to insert necessary ops and their layouts for relayout before unrolling in apply-vector-layout pass.

PiperOrigin-RevId: 708143852
2024-12-19 19:56:40 -08:00

278 lines
9.9 KiB
C++

/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include <array>
#include <cstdint>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include "llvm/Support/MathExtras.h"
#include "absl/log/check.h"
#include "absl/types/span.h"
#include "llvm/include/llvm/Support/raw_ostream.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include "mlir/include/mlir/IR/BuiltinTypes.h"
#include "mlir/include/mlir/IR/Value.h"
#include "mlir/include/mlir/IR/ValueRange.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
namespace mlir::tpu {
std::ostream &operator<<(std::ostream &os, Print p) {
std::string s;
llvm::raw_string_ostream tmp_os(s);
p.payload_->print(tmp_os);
os << tmp_os.str();
return os;
}
// Computes one stride per memref dimension, walking from the minormost
// dimension to the majormost.
//
// `tiling` is aligned with the trailing dimensions of `memref_ty`. A dimension
// covered by `tiling` contributes ceil(dim_size / tile_size) to the running
// stride; a leading dimension not covered by `tiling` contributes its full
// size. The minormost dimension always gets stride 1.
SmallVector<int64_t> ComputeTileStrides(MemRefType memref_ty,
                                        absl::Span<const int64_t> tiling) {
  const int64_t rank = memref_ty.getRank();
  const ArrayRef<int64_t> shape = memref_ty.getShape();
  // Offset that maps a memref dimension to its entry in `tiling`; the mapped
  // index is negative for dimensions that `tiling` does not cover.
  const int64_t untiled_dims = rank - static_cast<int64_t>(tiling.size());
  SmallVector<int64_t> tile_strides(rank);
  int64_t running = 1;
  for (int64_t dim = rank - 1; dim >= 0; --dim) {
    tile_strides[dim] = running;
    const int64_t tiling_dim = dim - untiled_dims;
    if (tiling_dim >= 0) {
      running *= llvm::divideCeil(shape[dim], tiling[tiling_dim]);
    } else {
      running *= shape[dim];
    }
  }
  return tile_strides;
}
// Classifies the operand orientation described by `dim_numbers`.
//
// Only handles the case where each operand has exactly one contracting and
// one non-contracting dimension; anything else yields std::nullopt.
// Otherwise returns {lhs_transposed, rhs_transposed}, where
//   - lhs is transposed when its non-contracting dim index is greater than
//     its contracting dim index, and
//   - rhs is transposed when its contracting dim index is greater than its
//     non-contracting dim index.
std::optional<std::pair<bool, bool>> isTransposedMatmul(
    DotDimensionNumbersAttr dim_numbers) {
  const auto lhs_contract = dim_numbers.getLhsContractingDims();
  const auto rhs_contract = dim_numbers.getRhsContractingDims();
  const auto lhs_free = dim_numbers.getLhsNonContractingDims();
  const auto rhs_free = dim_numbers.getRhsNonContractingDims();

  const bool single_dim_each = lhs_contract.size() == 1 &&
                               rhs_contract.size() == 1 &&
                               lhs_free.size() == 1 && rhs_free.size() == 1;
  if (!single_dim_each) {
    return std::nullopt;
  }

  const int64_t lhs_free_dim = lhs_free[0];
  const int64_t lhs_contract_dim = lhs_contract[0];
  const int64_t rhs_free_dim = rhs_free[0];
  const int64_t rhs_contract_dim = rhs_contract[0];
  return std::pair<bool, bool>{lhs_free_dim > lhs_contract_dim,
                               rhs_contract_dim > rhs_free_dim};
}
// Decides whether `tiled_memref` passes the checks required to reinterpret it
// as an untiled memref.
//
// Parameters:
//   tiled_memref: the memref value to inspect. If its type has no tiled
//     layout but it is produced by a tpu::EraseLayoutOp, the op's operand is
//     inspected instead.
//   target_shape: only target_shape[1] is read; it is compared against the
//     minormost dimension.
//   allow_minormost_padding: when false, a static minormost dimension must be
//     exactly target_shape[1].
//
// Returns false when any of the following holds:
//   - no tiled layout can be found;
//   - the number of dynamic dims differs from the number of dynamic sizes
//     recovered from a defining tpu::MemRefSliceOp;
//   - the layout has no tiles, its first tile is not 2D, or the rank is < 2;
//   - the element type is wider than 32 bits (packing would be zero);
//   - the minormost dim fails the target_shape[1] check described above;
//   - the second-minormost dim is not (guaranteed) divisible by the packing;
//   - either of the last two tile strides is not 1.
bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
                                   const std::array<int64_t, 2>& target_shape,
                                   bool allow_minormost_padding) {
  MemRefType tiled_memref_ty = tiled_memref.getType();
  auto tiled_layout =
      dyn_cast<tpu::TiledLayoutAttr>(tiled_memref_ty.getLayout());
  ValueRange dynamic_sizes = {};
  if (!tiled_layout) {
    if (auto erase_op = tiled_memref.getDefiningOp<tpu::EraseLayoutOp>()) {
      // Look through the layout erasure and inspect the original memref.
      tiled_memref = erase_op.getOperand();
      tiled_memref_ty = tiled_memref.getType();
      tiled_layout =
          dyn_cast<tpu::TiledLayoutAttr>(tiled_memref_ty.getLayout());
      // TODO(b/375641258): Currently we rely on the pattern `slice ->
      // (squeeze)* -> eraseLayout` to get the dynamic sizes, but other
      // patterns may not work here: e.g., `slice -> eraseLayout -> reshape ->
      // eraseLayout`. We should fix this! For now, if we cannot get the
      // expected dynamic sizes, we consider that the memref cannot be
      // reinterpreted to untiled.
      Value ref = tiled_memref;
      while (auto squeeze_op = ref.getDefiningOp<tpu::MemRefSqueezeOp>()) {
        ref = squeeze_op.getInput();
      }
      if (auto slice_op = ref.getDefiningOp<tpu::MemRefSliceOp>()) {
        dynamic_sizes = slice_op.getDynamicSizes();
      }
    }
  }
  if (!tiled_layout) {
    // We expect the tiled memref to have a tiled layout.
    return false;
  }
  // Every dynamic dimension must have a matching dynamic size value; this
  // also guarantees that the `dynamic_sizes.back()` calls below are in range.
  if (tiled_memref_ty.getNumDynamicDims() != dynamic_sizes.size()) {
    return false;
  }
  if (tiled_layout.getTiles().empty() ||
      tiled_layout.getTiles().front().dimensions().size() != 2 ||
      tiled_memref_ty.getRank() < 2) {
    // TODO(b/375642202): Currently we only support >= 2D memref, we might
    // need to handle 1D memref if we find a use case.
    return false;
  }
  auto rank = tiled_memref_ty.getRank();
  auto packing = 32 / tiled_memref_ty.getElementTypeBitWidth();
  if (packing == 0) {
    // Element types wider than 32 bits make the integer division above yield
    // zero. Bail out instead of dividing by zero in the checks below.
    return false;
  }
  if (tiled_memref_ty.isDynamicDim(rank - 1)) {
    // TODO(jevinjiang): we can still allow the minormost padding if we know the
    // max bound of the dynamic size is not larger than the target_shape[1].
    if (!isGuaranteedDivisible(dynamic_sizes.back(), target_shape[1])) {
      return false;
    }
    // The last dynamic size belonged to the minormost dim; drop it so that
    // back() now refers to the size of the next dynamic dim.
    dynamic_sizes = dynamic_sizes.drop_back();
  } else {
    if (!allow_minormost_padding &&
        tiled_memref_ty.getShape()[rank - 1] != target_shape[1]) {
      return false;
    }
  }
  if (tiled_memref_ty.isDynamicDim(rank - 2)) {
    if (!isGuaranteedDivisible(dynamic_sizes.back(), packing)) {
      return false;
    }
  } else {
    if (tiled_memref_ty.getShape()[rank - 2] % packing != 0) {
      return false;
    }
  }
  // Require the tile strides of the two minormost dims to both be 1.
  return *(tiled_layout.getTileStrides().end() - 1) == 1 &&
         *(tiled_layout.getTileStrides().end() - 2) == 1;
}
// Returns true iff the memory space attribute of `ty` is a
// tpu::MemorySpaceAttr whose value equals `space`. A missing attribute or one
// of a different kind yields false.
bool HasMemorySpace(MemRefType ty, tpu::MemorySpace space) {
  if (auto space_attr =
          dyn_cast_or_null<tpu::MemorySpaceAttr>(ty.getMemorySpace())) {
    return space_attr.getValue() == space;
  }
  return false;
}
// Returns whether layout `l` is acceptable for value `v`.
//
// For a vector-typed `v`, `l` must be non-null, the element type must be an
// integer or float whose bitwidth equals the layout's bitwidth (1-bit
// elements are exempt from the bitwidth match), the layout must be valid for
// `target_shape`, and the layout rank must not exceed the vector rank.
// For any other type, `l` must be null.
bool layoutIsValidForValue(const Layout &l, const Value v,
                           const std::array<int64_t, 2> target_shape) {
  // l must be non-null iff v is of vector type
  if (const auto vty = dyn_cast<VectorType>(v.getType())) {
    if (!l.has_value()) {
      return false;
    }
    // Vector type should have the same bitwidth as the layout, except for the
    // i1 special case, used for vmasks (see comment for VectorLayout class).
    if (!vty.getElementType().isIntOrFloat()) {
      return false;
    }
    // Hold the width in a wide integer: narrowing it to int8_t would wrap for
    // element types of 128 bits or more, which could make a wide type look
    // like it matches the layout bitwidth or the 1-bit special case.
    const int64_t bitwidth = vty.getElementTypeBitWidth();
    if (bitwidth != l->bitwidth() && bitwidth != 1) {
      return false;
    }
    return l->isValid(target_shape) && l->layout_rank() <= vty.getRank();
  }
  return !l.has_value();
}
// Converts `attr` into a list of layouts.
//
// If `attr` is absent or is not an ArrayAttr, an empty list is returned. If
// it is an ArrayAttr, every element must be a VectorLayoutAttr; the layouts
// are returned in order, and failure is returned as soon as an element of
// any other kind (or a null element) is encountered.
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
  const auto entries = dyn_cast_if_present<ArrayAttr>(attr);
  if (!entries) {
    return SmallVector<Layout>{};
  }
  SmallVector<Layout> layouts;
  layouts.reserve(entries.size());
  for (const Attribute entry : entries) {
    auto layout_attr = dyn_cast_if_present<VectorLayoutAttr>(entry);
    if (!layout_attr) {
      return failure();
    }
    layouts.push_back(layout_attr.getLayout());
  }
  return layouts;
}
// TODO(tlongeri, jevinjiang): Unify with infer_vector_layout.cc's getOutLayout.
//
// Reads the "out_layout" attribute of `op` and returns one layout per result.
//
// Emits an op error (and returns failure) when the number of layouts differs
// from the number of results, or when any layout is not valid for its result
// under `target_shape`. A failure from parsing the attribute is presumably
// propagated by FAILUREOR_ASSIGN_OR_RETURN (macro defined outside this file).
FailureOr<SmallVector<Layout>> getOutLayouts(
    Operation &op, const std::array<int64_t, 2> target_shape) {
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> result_layouts,
                             getLayoutArrayFromAttr(op.getAttr("out_layout")));
  const unsigned num_results = op.getNumResults();
  if (result_layouts.size() != num_results) {
    return op.emitOpError("out_layout size does not match number of results");
  }
  for (unsigned i = 0; i < num_results; ++i) {
    const Value result = op.getResult(i);
    if (!layoutIsValidForValue(result_layouts[i], result, target_shape)) {
      return op.emitOpError("Invalid output layout");
    }
  }
  return result_layouts;
}
// Reads the "in_layout" attribute of `op` and returns one layout per operand.
//
// Emits an op error (and returns failure) when the number of layouts differs
// from the number of operands, or when any layout is not valid for its
// operand under `target_shape`. A failure from parsing the attribute is
// presumably propagated by FAILUREOR_ASSIGN_OR_RETURN (macro defined outside
// this file).
FailureOr<SmallVector<Layout>> getInLayouts(
    Operation &op, const std::array<int64_t, 2> target_shape) {
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> operand_layouts,
                             getLayoutArrayFromAttr(op.getAttr("in_layout")));
  const unsigned num_operands = op.getNumOperands();
  if (operand_layouts.size() != num_operands) {
    return op.emitOpError("in_layout size does not match number of operands");
  }
  for (unsigned i = 0; i < num_operands; ++i) {
    const Value operand = op.getOperand(i);
    if (!layoutIsValidForValue(operand_layouts[i], operand, target_shape)) {
      return op.emitOpError("Invalid input layout");
    }
  }
  return operand_layouts;
}
// Stores `in` on `op` as the "in_layout" array attribute, one
// VectorLayoutAttr per entry. CHECK-fails (printing the op) unless `in` has
// exactly one entry per operand.
void setInLayout(Operation *op, ArrayRef<Layout> in) {
  CHECK_EQ(in.size(), op->getNumOperands()) << Print(op);
  auto *ctx = op->getContext();
  SmallVector<Attribute, 4> layout_attrs;
  layout_attrs.reserve(in.size());
  for (const Layout &layout : in) {
    layout_attrs.push_back(VectorLayoutAttr::get(ctx, layout));
  }
  op->setAttr("in_layout", ArrayAttr::get(ctx, layout_attrs));
}
// Single-layout convenience overload: stores `out` as a one-element
// "out_layout" attribute on `op`.
void setOutLayout(Operation *op, Layout out) {
  const ArrayRef<Layout> single_layout(out);
  setOutLayout(op, single_layout);
}
// Stores `out` on `op` as the "out_layout" array attribute, one
// VectorLayoutAttr per entry. No size check against the op's results is
// performed here.
void setOutLayout(Operation *op, ArrayRef<Layout> out) {
  auto *ctx = op->getContext();
  SmallVector<Attribute, 4> layout_attrs;
  layout_attrs.reserve(out.size());
  for (const Layout &layout : out) {
    layout_attrs.push_back(VectorLayoutAttr::get(ctx, layout));
  }
  op->setAttr("out_layout", ArrayAttr::get(ctx, layout_attrs));
}
// The setLayout overloads set both the "in_layout" and "out_layout"
// attributes of `op`. A single Layout argument is treated as a one-element
// list. Inputs are always set before outputs.
void setLayout(Operation *op, Layout in, Layout out) {
  setInLayout(op, ArrayRef<Layout>(in));
  setOutLayout(op, ArrayRef<Layout>(out));
}
void setLayout(Operation *op, ArrayRef<Layout> in, Layout out) {
  setInLayout(op, in);
  setOutLayout(op, ArrayRef<Layout>(out));
}
void setLayout(Operation *op, Layout in, ArrayRef<Layout> out) {
  setInLayout(op, ArrayRef<Layout>(in));
  setOutLayout(op, out);
}
void setLayout(Operation *op, ArrayRef<Layout> in, ArrayRef<Layout> out) {
  // Input layouts first (this CHECKs the operand count), then outputs.
  setInLayout(op, in);
  setOutLayout(op, out);
}
} // namespace mlir::tpu