[Mosaic] apply_vector_layout C++ rewrite (4) Elementwise ops

PiperOrigin-RevId: 565255860
Tomás Longeri 2023-09-13 22:18:30 -07:00 committed by jax authors
parent 6869000636
commit 838f59e576
2 changed files with 396 additions and 1 deletion


@@ -57,6 +57,7 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",


@@ -1,3 +1,4 @@
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdlib>
@@ -14,6 +15,8 @@
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
@@ -83,11 +86,402 @@ FailureOr<VectorType> getNativeVregType(
elem_ty);
}
LogicalResult elementwise_op_rule(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
std::function<FailureOr<Operation *>(RewriteContext &, ArrayRef<Value>)>
factory) {
CHECK_EQ(layouts_in.size(), op.getNumOperands());
CHECK_GT(layouts_in.size(), 0);
CHECK_EQ(layouts_out.size(), 1);
if (!(layouts_out.front().has_value() &&
llvm::all_of(layouts_in,
[&](const Layout &l) { return l.has_value(); }))) {
return op.emitOpError("null layout in elementwise operation");
}
const auto vty = cast<VectorType>(op.getResult(0).getType());
const VectorLayout &layout_out = *layouts_out.front();
if (!llvm::all_of(layouts_in, [&](const Layout &l) {
return l->generalizes(layout_out, vty.getShape(), ctx.target_shape);
})) {
return op.emitOpError("incompatible layouts in elementwise operation");
}
const unsigned num_operands = op.getNumOperands();
SmallVector<xla::Array<Value>> in_tile_arrays;
in_tile_arrays.reserve(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tile_array,
disassemble(ctx, *layouts_in[i], op.getOperand(i)));
in_tile_arrays.emplace_back(std::move(tile_array));
}
// Note that we have to broadcast to handle replicated dimensions.
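// For example, an operand with a replicated dimension may disassemble into a
// tile array of extent 1 along that dimension, which is then broadcast
// against the other operands' tile arrays.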
SmallVector<int64_t> broadcasted_shape(
toArrayRef(in_tile_arrays[0].dimensions()));
for (size_t i = 1; i < num_operands; ++i) {
SmallVector<int64_t> new_broadcasted_shape;
CHECK(OpTrait::util::getBroadcastedShape(
broadcasted_shape, toArrayRef(in_tile_arrays[i].dimensions()),
new_broadcasted_shape));
broadcasted_shape = std::move(new_broadcasted_shape);
}
// TODO(tlongeri): Can we avoid initializing the array before filling values?
xla::Array<Value> out_tile_array(broadcasted_shape);
absl::Status status =
out_tile_array.EachStatus([&](absl::Span<const int64_t> idx, Value *v) {
SmallVector<Value> operands(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
// Handle indices for broadcasted dimensions
SmallVector<int64_t> operand_idx(toArrayRef(idx));
for (unsigned j = 0; j < idx.size(); ++j) {
if (in_tile_arrays[i].dim(j) == 1) {
operand_idx[j] = 0;
}
}
operands[i] = in_tile_arrays[i](operand_idx);
}
FailureOr<Operation *> failure_or_tile_op = factory(ctx, operands);
if (failed(failure_or_tile_op)) {
return absl::InvalidArgumentError("");
}
Operation *tile_op = *failure_or_tile_op;
CHECK(tile_op);
CHECK_EQ(tile_op->getNumResults(), 1);
*v = tile_op->getResult(0);
return absl::OkStatus();
});
if (!status.ok()) {
return failure();
}
op.replaceAllUsesWith(
assemble(ctx, vty, layout_out, std::move(out_tile_array)));
op.erase();
return success();
}
// Helper for index_sequence expansion
template <typename T, std::size_t>
using Wrapper = T;
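// For example, std::index_sequence<0, 1> expands Wrapper<Value, I>... into
// (Value, Value), fixing the factory's arity at compile time.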
template <std::size_t... I>
LogicalResult elementwise_op_rule_unpacked_impl(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
std::function<FailureOr<Operation *>(RewriteContext &ctx,
Wrapper<Value, I>...)>
factory,
std::index_sequence<I...>) {
return elementwise_op_rule(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx,
ArrayRef<Value> operands) -> FailureOr<Operation *> {
if (operands.size() != sizeof...(I)) {
return failure();
}
return factory(ctx, operands[I]...);
});
}
// Like elementwise_op_rule, but operands are "unpacked" into individual
// arguments for the factory.
// Returns failure if the number of operands does not match NumOperands.
template <std::size_t NumOperands, typename Func>
LogicalResult elementwise_op_rule_unpacked(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
Func factory) {
return elementwise_op_rule_unpacked_impl(
ctx, op, layouts_in, layouts_out, std::move(factory),
std::make_index_sequence<NumOperands>());
}
using rule_type = std::function<LogicalResult(
RewriteContext &, Operation &, ArrayRef<Layout>, ArrayRef<Layout>)>;
LogicalResult arith_cmpf_rule(RewriteContext &ctx, Operation &op,
ArrayRef<Layout> layouts_in,
ArrayRef<Layout> layouts_out) {
auto cmpf_op = cast<arith::CmpFOp>(op);
return elementwise_op_rule_unpacked<2>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx, const Value lhs,
const Value rhs) -> FailureOr<Operation *> {
return ctx.builder
.create<arith::CmpFOp>(cmpf_op.getLoc(), cmpf_op.getPredicateAttr(),
lhs, rhs)
.getOperation();
});
}
LogicalResult arith_cmpi_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
auto cmpi_op = cast<arith::CmpIOp>(op);
return elementwise_op_rule_unpacked<2>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx, const Value lhs,
const Value rhs) -> FailureOr<Operation *> {
return ctx.builder
.create<arith::CmpIOp>(cmpi_op.getLoc(), cmpi_op.getPredicateAttr(),
lhs, rhs)
.getOperation();
});
}
LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
auto extui_op = cast<arith::ExtUIOp>(op);
const Type elem_ty =
cast<VectorType>(extui_op.getResult().getType()).getElementType();
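// extui gets a dedicated rule (instead of rules_elementwise_op_entry) because
// the result element type differs from the operand's, so the result vreg type
// must be rebuilt explicitly for each tile.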
return elementwise_op_rule_unpacked<1>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx, const Value x) -> FailureOr<Operation *> {
const VectorType x_ty = cast<VectorType>(x.getType());
const VectorType out_ty = VectorType::get(x_ty.getShape(), elem_ty);
return ctx.builder.create<arith::ExtUIOp>(extui_op.getLoc(), out_ty, x)
.getOperation();
});
}
template <typename OpTy>
LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
auto result_ty = cast<VectorType>(op.getResult().getType());
if (layout_out.bitwidth() != 32) {
return op.emitOpError("Only extensions to 32-bit supported");
}
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
disassemble(ctx, layout_in, op.getIn()));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
}
switch (layout_in.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone: {
if (layout_in.tiling() != ctx.target_shape ||
layout_out.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: tiling not supported");
}
const int packing = layout_in.packing();
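// E.g. for a 16-bit input, packing == 2: each input vreg unpacks into two
// 32-bit output vregs, with vreg_part selecting which half to unpack.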
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = ctx.builder.create<UnpackSubelementsOp>(
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
} break;
case VectorLayout::ImplicitDim::kMinor:
return op.emitOpError(
"Not implemented: Only casts of lane-oriented values supported");
case VectorLayout::ImplicitDim::kSecondMinor: {
if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
output_vregs.dimensions() != absl::Span<const int64_t>{1}) {
return op.emitOpError("Not implemented");
}
if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
return op.emitOpError("Not implemented");
}
auto unpack_subelements_op = ctx.builder.create<UnpackSubelementsOp>(
op.getLoc(), res_vreg_ty, *input_vregs.begin(), 0);
output_vregs.Fill(unpack_subelements_op.getResult());
}
}
op.replaceAllUsesWith(
assemble(ctx, result_ty, layout_out, std::move(output_vregs))
.getResult());
op.erase();
return success();
}
LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK(layouts_in.front().has_value());
CHECK(layouts_out.front().has_value());
auto extf_op = cast<arith::ExtFOp>(op);
if (layouts_in.front()->bitwidth() != 16 ||
layouts_out.front()->bitwidth() != 32) {
return op.emitOpError("Only 16-bit to 32-bit conversion supported");
}
return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(),
*layouts_out.front());
}
LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK(layouts_in.front().has_value());
CHECK_EQ(layouts_out.size(), 1);
CHECK(layouts_out.front().has_value());
auto extsi_op = cast<arith::ExtSIOp>(op);
return ext_op_rule_impl(ctx, extsi_op, *layouts_in.front(),
*layouts_out.front());
}
template <typename OpTy>
LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
auto result_ty = cast<VectorType>(op.getResult().getType());
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
disassemble(ctx, layout_in, op.getIn()));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Only 32-bit truncation supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) {
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
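// E.g. for a 32-bit to 16-bit truncation, packing == 2: each output vreg
// packs the subelements of two consecutive input vregs.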
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
// If the next index would be OOB, stay on the last vreg and pack
// whatever data is lying around.
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
}
}
*v = ctx.builder.create<PackSubelementsOp>(op.getLoc(), res_vreg_ty,
parts);
});
} else if (layout_out.bitwidth() == 16 &&
layout_out.tiling() ==
std::array<int64_t, 2>{2 * ctx.target_shape[0],
ctx.target_shape[1]}) {
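// This is the "large" 16-bit tiling (2 * sublanes, lanes): each output
// vreg packs a pair of input vregs that are adjacent along the
// second-minor tile index.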
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
// TODO(tlongeri): should probably express as a multiple of target_shape
// instead of (16, 128)
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= 2;
const Value first = input_vregs(idxs_local);
Value second;
if (idxs[idxs.size() - 2] * 2 + 1 ==
input_vregs.dim(input_vregs.num_dimensions() - 2)) {
second = first;
} else {
idxs_local[idxs.size() - 2] += 1;
second = input_vregs(idxs_local);
}
*v = ctx.builder.create<PackSubelementsOp>(
op.getLoc(), res_vreg_ty, ArrayRef<Value>{first, second});
});
} else {
return op.emitOpError("Not implemented");
}
op.replaceAllUsesWith(
assemble(ctx, result_ty, layout_out, std::move(output_vregs))
.getResult());
op.erase();
return success();
}
// TODO(tlongeri): why wasn't this part of the original code?
return op.emitOpError("Not implemented");
}
LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK(layouts_in.front().has_value());
CHECK_EQ(layouts_out.size(), 1);
CHECK(layouts_out.front().has_value());
auto truncf_op = cast<arith::TruncFOp>(op);
if (layouts_in.front()->bitwidth() != 32 ||
layouts_out.front()->bitwidth() != 16) {
return op.emitOpError(
"Not implemented: Only 32-bit to 16-bit conversion supported");
}
return trunc_op_rule_impl(ctx, truncf_op, *layouts_in.front(),
*layouts_out.front());
}
LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK(layouts_in.front().has_value());
CHECK_EQ(layouts_out.size(), 1);
CHECK(layouts_out.front().has_value());
auto trunci_op = cast<arith::TruncIOp>(op);
return trunc_op_rule_impl(ctx, trunci_op, *layouts_in.front(),
*layouts_out.front());
}
template <typename Op, std::size_t NumOperands>
std::pair<StringRef, rule_type> rules_elementwise_op_entry() {
return {
Op::getOperationName(),
[](RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) -> LogicalResult {
return elementwise_op_rule_unpacked<NumOperands>(
ctx, op, layouts_in, layouts_out,
[&](RewriteContext &ctx,
auto... operands) -> FailureOr<Operation *> {
return ctx.builder.create<Op>(op.getLoc(), operands...)
.getOperation();
});
}};
}
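// For instance, rules_elementwise_op_entry<arith::AddFOp, 2>() produces an
// entry mapping "arith.addf" to a rule that recreates the op over each pair
// of corresponding input vregs.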
const llvm::StringMap<rule_type> &rules() {
static auto rules = new llvm::StringMap<rule_type>{
rules_elementwise_op_entry<arith::AddFOp, 2>(),
rules_elementwise_op_entry<arith::AddIOp, 2>(),
{arith::CmpFOp::getOperationName(), arith_cmpf_rule},
{arith::CmpIOp::getOperationName(), arith_cmpi_rule},
{arith::ExtFOp::getOperationName(), arith_extf_rule},
{arith::ExtSIOp::getOperationName(), arith_extsi_rule},
{arith::ExtUIOp::getOperationName(), arith_extui_rule},
{arith::TruncFOp::getOperationName(), arith_truncf_rule},
{arith::TruncIOp::getOperationName(), arith_trunci_rule},
rules_elementwise_op_entry<arith::SubFOp, 2>(),
rules_elementwise_op_entry<arith::SubIOp, 2>(),
rules_elementwise_op_entry<arith::MulFOp, 2>(),
rules_elementwise_op_entry<arith::MulIOp, 2>(),
rules_elementwise_op_entry<arith::DivFOp, 2>(),
rules_elementwise_op_entry<arith::DivSIOp, 2>(),
rules_elementwise_op_entry<arith::RemSIOp, 2>(),
rules_elementwise_op_entry<arith::MaximumFOp, 2>(),
rules_elementwise_op_entry<arith::MinimumFOp, 2>(),
rules_elementwise_op_entry<arith::SelectOp, 3>(),
// TODO(tlongeri) arith::IndexCastOp
rules_elementwise_op_entry<arith::AndIOp, 2>(),
rules_elementwise_op_entry<arith::OrIOp, 2>(),
rules_elementwise_op_entry<arith::NegFOp, 1>(),
rules_elementwise_op_entry<arith::XOrIOp, 2>(),
rules_elementwise_op_entry<arith::ShLIOp, 2>(),
rules_elementwise_op_entry<arith::ShRUIOp, 2>(),
rules_elementwise_op_entry<math::ExpOp, 1>(),
rules_elementwise_op_entry<math::CosOp, 1>(),
rules_elementwise_op_entry<math::SinOp, 1>(),
rules_elementwise_op_entry<math::PowFOp, 2>(),
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
rules_elementwise_op_entry<math::TanhOp, 1>(),
};
return *rules;
}
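// Hypothetical usage sketch (not part of this commit): a dispatcher would
// look up an op's rule by name and apply it to the inferred layouts.
//
//   if (auto it = rules().find(op.getName().getStringRef());
//       it != rules().end()) {
//     return it->second(ctx, op, layouts_in, layouts_out);
//   }
//   return op.emitOpError("Not implemented: no apply-vector-layout rule");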
} // namespace