mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00
[Mosaic] apply_vector_layout C++ rewrite (4) Elementwise ops
PiperOrigin-RevId: 565255860
This commit is contained in:
parent
6869000636
commit
838f59e576
@ -57,6 +57,7 @@ cc_library(
|
||||
"@llvm-project//mlir:FuncDialect",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:MathDialect",
|
||||
"@llvm-project//mlir:MemRefDialect",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
|
@ -1,3 +1,4 @@
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
@ -14,6 +15,8 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
@ -83,11 +86,402 @@ FailureOr<VectorType> getNativeVregType(
|
||||
elem_ty);
|
||||
}
|
||||
|
||||
// Rewrites an elementwise op so that it operates tile by tile.
//
// Each operand is disassembled into its array of tiles using the matching
// entry of `layouts_in`. `factory` is then invoked once per output tile with
// the corresponding operand tiles (operand tile arrays whose extent along a
// dimension is 1 are broadcast along that dimension), and the per-tile results
// are assembled with the output layout into the replacement for `op`, which is
// erased on success.
//
// Emits an op error and returns failure when any layout is null or when an
// input layout does not generalize the output layout. Also returns failure
// when disassembling an operand fails or `factory` fails for any tile.
LogicalResult elementwise_op_rule(
    RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
    const ArrayRef<Layout> layouts_out,
    std::function<FailureOr<Operation *>(RewriteContext &, ArrayRef<Value>)>
        factory) {
  CHECK_EQ(layouts_in.size(), op.getNumOperands());
  CHECK_GT(layouts_in.size(), 0);
  CHECK_EQ(layouts_out.size(), 1);
  if (!layouts_out.front().has_value() ||
      llvm::any_of(layouts_in, [](const Layout &candidate) {
        return !candidate.has_value();
      })) {
    return op.emitOpError("null layout in elementwise operation");
  }
  const auto vty = cast<VectorType>(op.getResult(0).getType());
  const VectorLayout &layout_out = *layouts_out.front();
  for (const Layout &candidate : layouts_in) {
    if (!candidate->generalizes(layout_out, vty.getShape(),
                                ctx.target_shape)) {
      return op.emitOpError("incompatible layouts in elementwise operation");
    }
  }

  // Split every operand into its tiles.
  const unsigned num_operands = op.getNumOperands();
  SmallVector<xla::Array<Value>> in_tile_arrays;
  in_tile_arrays.reserve(num_operands);
  for (unsigned operand_no = 0; operand_no < num_operands; ++operand_no) {
    FAILUREOR_ASSIGN_OR_RETURN(
        xla::Array<Value> operand_tiles,
        disassemble(ctx, *layouts_in[operand_no], op.getOperand(operand_no)));
    in_tile_arrays.push_back(std::move(operand_tiles));
  }

  // Replicated dimensions make the tile arrays differ in shape, so the output
  // tile array takes the broadcast of all operand tile array shapes.
  SmallVector<int64_t> broadcasted_shape(
      toArrayRef(in_tile_arrays[0].dimensions()));
  for (size_t i = 1; i < num_operands; ++i) {
    SmallVector<int64_t> new_broadcasted_shape;
    CHECK(OpTrait::util::getBroadcastedShape(
        broadcasted_shape, toArrayRef(in_tile_arrays[i].dimensions()),
        new_broadcasted_shape));
    broadcasted_shape = std::move(new_broadcasted_shape);
  }

  // TODO(tlongeri): Can we avoid initializing the array before filling values?
  xla::Array<Value> out_tile_array(broadcasted_shape);
  absl::Status status = out_tile_array.EachStatus(
      [&](absl::Span<const int64_t> out_idx, Value *out_tile) {
        SmallVector<Value> tile_operands;
        tile_operands.reserve(num_operands);
        for (const xla::Array<Value> &operand_tiles : in_tile_arrays) {
          // Along broadcast dimensions the operand only has tile 0.
          SmallVector<int64_t> src_idx(toArrayRef(out_idx));
          for (unsigned dim = 0; dim < out_idx.size(); ++dim) {
            if (operand_tiles.dim(dim) == 1) {
              src_idx[dim] = 0;
            }
          }
          tile_operands.push_back(operand_tiles(src_idx));
        }
        FailureOr<Operation *> maybe_tile_op = factory(ctx, tile_operands);
        if (failed(maybe_tile_op)) {
          return absl::InvalidArgumentError("");
        }
        Operation *tile_op = *maybe_tile_op;
        CHECK(tile_op);
        CHECK_EQ(tile_op->getNumResults(), 1);
        *out_tile = tile_op->getResult(0);
        return absl::OkStatus();
      });
  if (!status.ok()) {
    return failure();
  }
  op.replaceAllUsesWith(
      assemble(ctx, vty, layout_out, std::move(out_tile_array)));
  op.erase();
  return success();
}
|
||||
|
||||
// Helper for index_sequence expansion
|
||||
template <typename T, std::size_t>
|
||||
using Wrapper = T;
|
||||
|
||||
template <std::size_t... I>
|
||||
LogicalResult elementwise_op_rule_unpacked_impl(
|
||||
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layout_in,
|
||||
const ArrayRef<Layout> layout_out,
|
||||
std::function<FailureOr<Operation *>(RewriteContext &ctx,
|
||||
Wrapper<Value, I>...)>
|
||||
factory,
|
||||
std::index_sequence<I...>) {
|
||||
return elementwise_op_rule(
|
||||
ctx, op, layout_in, layout_out,
|
||||
[&](RewriteContext &ctx,
|
||||
ArrayRef<Value> operands) -> FailureOr<Operation *> {
|
||||
if (operands.size() != sizeof...(I)) {
|
||||
return failure();
|
||||
}
|
||||
return factory(ctx, operands[I]...);
|
||||
});
|
||||
}
|
||||
|
||||
// Like elementwise_op_rule, but operands are "unpacked" into individual
|
||||
// arguments for the factory.
|
||||
// Returns failure if the number of operands is not the one expected (i.e. it
|
||||
// doesn't match NumOperands).
|
||||
template <std::size_t NumOperands, typename Func>
|
||||
LogicalResult elementwise_op_rule_unpacked(RewriteContext &ctx, Operation &op,
|
||||
const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out,
|
||||
Func factory) {
|
||||
return elementwise_op_rule_unpacked_impl(
|
||||
ctx, op, layouts_in, layouts_out, std::move(factory),
|
||||
std::make_index_sequence<NumOperands>());
|
||||
}
|
||||
|
||||
using rule_type = std::function<LogicalResult(
|
||||
RewriteContext &, Operation &, ArrayRef<Layout>, ArrayRef<Layout>)>;
|
||||
|
||||
// Rewrite rule for arith.cmpf: emits one arith.cmpf per tile, reusing the
// location and predicate attribute of the original op.
LogicalResult arith_cmpf_rule(RewriteContext &ctx, Operation &op,
                              ArrayRef<Layout> layouts_in,
                              ArrayRef<Layout> layouts_out) {
  auto cmpf_op = cast<arith::CmpFOp>(op);
  const auto make_tile_cmp = [&](RewriteContext &tile_ctx, const Value lhs,
                                 const Value rhs) -> FailureOr<Operation *> {
    auto tile_cmp = tile_ctx.builder.create<arith::CmpFOp>(
        cmpf_op.getLoc(), cmpf_op.getPredicateAttr(), lhs, rhs);
    return tile_cmp.getOperation();
  };
  return elementwise_op_rule_unpacked<2>(ctx, op, layouts_in, layouts_out,
                                         make_tile_cmp);
}
|
||||
|
||||
// Rewrite rule for arith.cmpi: emits one arith.cmpi per tile, reusing the
// location and predicate attribute of the original op.
LogicalResult arith_cmpi_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  auto cmpi_op = cast<arith::CmpIOp>(op);
  const auto make_tile_cmp = [&](RewriteContext &tile_ctx, const Value lhs,
                                 const Value rhs) -> FailureOr<Operation *> {
    auto tile_cmp = tile_ctx.builder.create<arith::CmpIOp>(
        cmpi_op.getLoc(), cmpi_op.getPredicateAttr(), lhs, rhs);
    return tile_cmp.getOperation();
  };
  return elementwise_op_rule_unpacked<2>(ctx, op, layouts_in, layouts_out,
                                         make_tile_cmp);
}
|
||||
|
||||
// Rewrite rule for arith.extui: emits one arith.extui per tile. Each tile
// result keeps the shape of the input tile and takes the element type of the
// original op's result.
LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
  auto extui_op = cast<arith::ExtUIOp>(op);
  const auto full_result_ty = cast<VectorType>(extui_op.getResult().getType());
  const Type elem_ty = full_result_ty.getElementType();
  const auto make_tile_ext = [&](RewriteContext &tile_ctx,
                                 const Value x) -> FailureOr<Operation *> {
    const VectorType tile_in_ty = cast<VectorType>(x.getType());
    const VectorType tile_out_ty =
        VectorType::get(tile_in_ty.getShape(), elem_ty);
    auto tile_ext = tile_ctx.builder.create<arith::ExtUIOp>(extui_op.getLoc(),
                                                            tile_out_ty, x);
    return tile_ext.getOperation();
  };
  return elementwise_op_rule_unpacked<1>(ctx, op, layouts_in, layouts_out,
                                         make_tile_ext);
}
|
||||
|
||||
template <typename OpTy>
|
||||
LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
|
||||
const VectorLayout &layout_in,
|
||||
const VectorLayout &layout_out) {
|
||||
auto result_ty = cast<VectorType>(op.getResult().getType());
|
||||
if (layout_out.bitwidth() != 32) {
|
||||
return op.emitOpError("Only extensions to 32-bit supported");
|
||||
}
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
|
||||
disassemble(ctx, layout_in, op.getIn()));
|
||||
xla::Array<Value> output_vregs(
|
||||
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const VectorType res_vreg_ty,
|
||||
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
|
||||
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
|
||||
return op.emitOpError("Not implemented: Change of layout during the cast");
|
||||
}
|
||||
switch (layout_in.implicit_dim()) {
|
||||
case VectorLayout::ImplicitDim::kNone: {
|
||||
if (layout_in.tiling() != ctx.target_shape ||
|
||||
layout_out.tiling() != ctx.target_shape) {
|
||||
return op.emitOpError("Not implemented: tiling not supported");
|
||||
}
|
||||
const int packing = layout_in.packing();
|
||||
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
|
||||
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
|
||||
input_vreg_idxs.back() /= packing;
|
||||
const int64_t vreg_part = idxs.back() % packing;
|
||||
*v = ctx.builder.create<UnpackSubelementsOp>(
|
||||
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
|
||||
});
|
||||
} break;
|
||||
case VectorLayout::ImplicitDim::kMinor:
|
||||
return op.emitOpError(
|
||||
"Not implemented: Only casts of lane-oriented values supported");
|
||||
case VectorLayout::ImplicitDim::kSecondMinor: {
|
||||
if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
|
||||
output_vregs.dimensions() != absl::Span<const int64_t>{1}) {
|
||||
return op.emitOpError("Not implemented");
|
||||
}
|
||||
if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
|
||||
return op.emitOpError("Not implemented");
|
||||
}
|
||||
auto unpack_subelements_op = ctx.builder.create<UnpackSubelementsOp>(
|
||||
op.getLoc(), res_vreg_ty, *input_vregs.begin(), 0);
|
||||
output_vregs.Fill(unpack_subelements_op.getResult());
|
||||
}
|
||||
}
|
||||
op.replaceAllUsesWith(
|
||||
assemble(ctx, result_ty, layout_out, std::move(output_vregs))
|
||||
.getResult());
|
||||
op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Rewrite rule for arith.extf.
//
// Expects exactly one non-null input layout and one non-null output layout
// (CHECK-enforced). Only extension from a 16-bit layout to a 32-bit layout is
// supported; any other bitwidth combination emits an op error and returns
// failure. The rewrite itself is delegated to ext_op_rule_impl.
LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  // Validate the output layout count before touching layouts_out.front(),
  // matching the other single-operand cast rules.
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  auto extf_op = cast<arith::ExtFOp>(op);
  // The input must be 16-bit: the previous `!= 32` test contradicted the
  // error message and rejected every valid 16-bit -> 32-bit extension.
  if (layouts_in.front()->bitwidth() != 16 ||
      layouts_out.front()->bitwidth() != 32) {
    return op.emitOpError("Only 16-bit to 32-bit conversion supported");
  }
  return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(),
                          *layouts_out.front());
}
|
||||
|
||||
// Rewrite rule for arith.extsi.
//
// CHECKs that there is exactly one non-null input layout and one non-null
// output layout, then hands the op over to ext_op_rule_impl.
LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  const VectorLayout &in_layout = *layouts_in.front();
  const VectorLayout &out_layout = *layouts_out.front();
  return ext_op_rule_impl(ctx, cast<arith::ExtSIOp>(op), in_layout,
                          out_layout);
}
|
||||
|
||||
template <typename OpTy>
|
||||
LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
|
||||
const VectorLayout &layout_in,
|
||||
const VectorLayout &layout_out) {
|
||||
auto result_ty = cast<VectorType>(op.getResult().getType());
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
|
||||
disassemble(ctx, layout_in, op.getIn()));
|
||||
xla::Array<Value> output_vregs(
|
||||
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
|
||||
if (layout_in.bitwidth() != 32) {
|
||||
return op.emitOpError("Only 32-bit truncation supported");
|
||||
}
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
VectorType res_vreg_ty,
|
||||
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
|
||||
if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
|
||||
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) {
|
||||
if (layout_in.tiling() != ctx.target_shape) {
|
||||
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
|
||||
}
|
||||
if (layout_out.tiling() == ctx.target_shape) {
|
||||
const int packing = layout_out.packing();
|
||||
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
|
||||
SmallVector<Value> parts;
|
||||
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
|
||||
idxs_local.back() *= packing;
|
||||
for (int64_t i = 0; i < packing; ++i) {
|
||||
parts.push_back(input_vregs(idxs_local));
|
||||
// Pack any data lying around if OOB
|
||||
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
|
||||
++idxs_local.back();
|
||||
}
|
||||
}
|
||||
*v = ctx.builder.create<PackSubelementsOp>(op.getLoc(), res_vreg_ty,
|
||||
parts);
|
||||
});
|
||||
|
||||
} else if (layout_out.bitwidth() == 16 &&
|
||||
layout_out.tiling() ==
|
||||
std::array<int64_t, 2>{2 * ctx.target_shape[0],
|
||||
ctx.target_shape[1]}) {
|
||||
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
|
||||
// TODO(tlongeri): should probably express as a multiple of target_shape
|
||||
// instead of (16, 128)
|
||||
CHECK_GE(idxs.size(), 2);
|
||||
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
|
||||
idxs_local[idxs.size() - 2] *= 2;
|
||||
const Value first = input_vregs(idxs_local);
|
||||
Value second;
|
||||
if (idxs[idxs.size() - 2] * 2 + 1 ==
|
||||
input_vregs.dim(input_vregs.num_dimensions() - 2)) {
|
||||
second = first;
|
||||
} else {
|
||||
idxs_local[idxs.size() - 2] += 1;
|
||||
second = input_vregs(idxs_local);
|
||||
}
|
||||
*v = ctx.builder.create<PackSubelementsOp>(
|
||||
op.getLoc(), res_vreg_ty, ArrayRef<Value>{first, second});
|
||||
});
|
||||
} else {
|
||||
return op.emitOpError("Not implemented");
|
||||
}
|
||||
op.replaceAllUsesWith(
|
||||
assemble(ctx, result_ty, layout_out, std::move(output_vregs))
|
||||
.getResult());
|
||||
op.erase();
|
||||
return success();
|
||||
}
|
||||
// TODO(tlongeri): why wasn't this part of the original code?
|
||||
return op.emitOpError("Not implemented");
|
||||
}
|
||||
|
||||
// Rewrite rule for arith.truncf.
//
// CHECKs that there is exactly one non-null input layout and one non-null
// output layout. Only truncation from a 32-bit layout to a 16-bit layout is
// accepted; anything else emits an op error and returns failure. The rewrite
// itself is delegated to trunc_op_rule_impl.
LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef<Layout> layouts_in,
                                const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  auto truncf_op = cast<arith::TruncFOp>(op);
  const VectorLayout &in_layout = *layouts_in.front();
  const VectorLayout &out_layout = *layouts_out.front();
  const bool is_32_to_16 =
      in_layout.bitwidth() == 32 && out_layout.bitwidth() == 16;
  if (!is_32_to_16) {
    return op.emitOpError(
        "Not implemented: Only 32-bit to 16-bit conversion supported");
  }
  return trunc_op_rule_impl(ctx, truncf_op, in_layout, out_layout);
}
|
||||
|
||||
// Rewrite rule for arith.trunci.
//
// CHECKs that there is exactly one non-null input layout and one non-null
// output layout, then hands the op over to trunc_op_rule_impl.
LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef<Layout> layouts_in,
                                const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  const VectorLayout &in_layout = *layouts_in.front();
  const VectorLayout &out_layout = *layouts_out.front();
  return trunc_op_rule_impl(ctx, cast<arith::TruncIOp>(op), in_layout,
                            out_layout);
}
|
||||
|
||||
template <typename Op, std::size_t NumOperands>
|
||||
std::pair<StringRef, rule_type> rules_elementwise_op_entry() {
|
||||
return {
|
||||
Op::getOperationName(),
|
||||
[](RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
|
||||
const ArrayRef<Layout> layouts_out) -> LogicalResult {
|
||||
return elementwise_op_rule_unpacked<NumOperands>(
|
||||
ctx, op, layouts_in, layouts_out,
|
||||
[&](RewriteContext &ctx,
|
||||
auto... operands) -> FailureOr<Operation *> {
|
||||
return ctx.builder.create<Op>(op.getLoc(), operands...)
|
||||
.getOperation();
|
||||
});
|
||||
}};
|
||||
}
|
||||
|
||||
// Returns the table mapping operation names to their rewrite rules.
//
// The table is built once on first use and intentionally never destroyed, so
// the returned reference stays valid for the rest of the program.
const llvm::StringMap<rule_type> &rules() {
  // A single definition of the table: the stale empty-map declaration that
  // used to precede this one redefined `rules` and has been dropped.
  static const auto *const rules = new llvm::StringMap<rule_type>{
      rules_elementwise_op_entry<arith::AddFOp, 2>(),
      rules_elementwise_op_entry<arith::AddIOp, 2>(),
      {arith::CmpFOp::getOperationName(), arith_cmpf_rule},
      {arith::CmpIOp::getOperationName(), arith_cmpi_rule},
      {arith::ExtFOp::getOperationName(), arith_extf_rule},
      {arith::ExtSIOp::getOperationName(), arith_extsi_rule},
      {arith::ExtUIOp::getOperationName(), arith_extui_rule},
      {arith::TruncFOp::getOperationName(), arith_truncf_rule},
      {arith::TruncIOp::getOperationName(), arith_trunci_rule},
      rules_elementwise_op_entry<arith::SubFOp, 2>(),
      rules_elementwise_op_entry<arith::SubIOp, 2>(),
      rules_elementwise_op_entry<arith::MulFOp, 2>(),
      rules_elementwise_op_entry<arith::MulIOp, 2>(),
      rules_elementwise_op_entry<arith::DivFOp, 2>(),
      rules_elementwise_op_entry<arith::DivSIOp, 2>(),
      rules_elementwise_op_entry<arith::RemSIOp, 2>(),
      rules_elementwise_op_entry<arith::MaximumFOp, 2>(),
      rules_elementwise_op_entry<arith::MinimumFOp, 2>(),
      rules_elementwise_op_entry<arith::SelectOp, 3>(),
      // TODO(tlongeri) arith::IndexCastOp
      rules_elementwise_op_entry<arith::AndIOp, 2>(),
      rules_elementwise_op_entry<arith::OrIOp, 2>(),
      rules_elementwise_op_entry<arith::NegFOp, 1>(),
      rules_elementwise_op_entry<arith::XOrIOp, 2>(),
      rules_elementwise_op_entry<arith::ShLIOp, 2>(),
      rules_elementwise_op_entry<arith::ShRUIOp, 2>(),
      rules_elementwise_op_entry<math::ExpOp, 1>(),
      rules_elementwise_op_entry<math::CosOp, 1>(),
      rules_elementwise_op_entry<math::SinOp, 1>(),
      // math.powf is binary (base, exponent). Registering it with one operand
      // made the unpacked rule fail its operand-count check on every powf.
      rules_elementwise_op_entry<math::PowFOp, 2>(),
      rules_elementwise_op_entry<math::RsqrtOp, 1>(),
      rules_elementwise_op_entry<math::TanhOp, 1>(),
  };
  return *rules;
}
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user