[Mosaic] apply_vector_layout C++ rewrite: various bug fixes

PiperOrigin-RevId: 571075082
This commit is contained in:
Tomás Longeri 2023-10-05 11:10:15 -07:00 committed by jax authors
parent 295cecd505
commit ab4a8e3417
4 changed files with 34 additions and 24 deletions

View File

@@ -276,9 +276,13 @@ class TiledRectangularVregBounds : public VRegDataBounds {
return VectorType::get(target_shape, i1);
}());
if (isComplete(target_shape)) {
builder.create<arith::ConstantOp>(
loc, mask_vreg_ty,
DenseElementsAttr::get(mask_vreg_ty, builder.getBoolAttr(true)));
return cast<TypedValue<VectorType>>(
builder
.create<arith::ConstantOp>(
loc, mask_vreg_ty,
DenseElementsAttr::get(mask_vreg_ty,
builder.getBoolAttr(true)))
.getResult());
}
Value mask = nullptr;
CHECK_GE(num_tiles_, 0);
@@ -488,7 +492,7 @@ llvm::SmallVector<int64_t> VectorLayout::tileArrayShape(
tiles_shape.pop_back();
break;
case ImplicitDim::kSecondMinor:
tiles_shape.erase(tiles_shape.end() - 1);
tiles_shape.erase(tiles_shape.end() - 2);
break;
}
return tiles_shape;

View File

@@ -60,9 +60,9 @@ struct VRegDataBounds {
std::array<int64_t, 2> target_shape) const = 0;
bool isComplete(const std::array<int64_t, 2> target_shape) const {
return maskVariesAlong(Direction::kSublanes, target_shape) ||
maskVariesAlong(Direction::kLanes, target_shape) ||
maskVariesAlong(Direction::kSubelements, target_shape);
return !maskVariesAlong(Direction::kSublanes, target_shape) &&
!maskVariesAlong(Direction::kLanes, target_shape) &&
!maskVariesAlong(Direction::kSubelements, target_shape);
}
// Constructs a vector mask value that is true iff the entry contains useful

View File

@@ -83,7 +83,8 @@ struct RewriteContext {
LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
const VectorLayout &layout, xla::Array<Value> vals);
const VectorLayout &layout,
const xla::Array<Value> &vals);
FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
const VectorLayout &layout, Value val);
namespace {
@@ -319,6 +320,9 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
vector_constants.push_back(value);
ctx.func->setAttr("vector_constants",
ArrayAttr::get(ctx.func.getContext(), vector_constants));
} else {
ctx.func->setAttr("vector_constants",
ArrayAttr::get(ctx.func.getContext(), value));
}
// Adjust window params for the extra operand.
if (auto window_params =
@@ -337,7 +341,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
StringAttr::get(ctx.func.getContext(), "transform_indices"),
AffineMapAttr::get(transform_indices)));
SmallVector<Attribute> window_params_values(window_params.getValue());
window_params_values.push_back(new_param);
window_params_values.insert(window_params_values.end() - 1, new_param);
ctx.func->setAttr("window_params", ArrayAttr::get(ctx.func.getContext(),
window_params_values));
}
@@ -650,7 +654,7 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
CHECK(layouts_in.front().has_value());
CHECK(layouts_out.front().has_value());
auto extf_op = cast<arith::ExtFOp>(op);
if (layouts_in.front()->bitwidth() != 32 ||
if (layouts_in.front()->bitwidth() != 16 ||
layouts_out.front()->bitwidth() != 32) {
return op.emitOpError("Only 16-bit to 32-bit conversion supported");
}
@@ -2526,7 +2530,7 @@ const llvm::StringMap<rule_type> &rules() {
rules_elementwise_op_entry<math::ExpOp, 1>(),
rules_elementwise_op_entry<math::CosOp, 1>(),
rules_elementwise_op_entry<math::SinOp, 1>(),
rules_elementwise_op_entry<math::PowFOp, 1>(),
rules_elementwise_op_entry<math::PowFOp, 2>(),
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
rules_elementwise_op_entry<math::TanhOp, 1>(),
{func::ReturnOp::getOperationName(), func_return_rule},
@@ -2554,7 +2558,7 @@ const llvm::StringMap<rule_type> &rules() {
RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
const VectorLayout &layout,
const xla::Array<Value> vals) {
const xla::Array<Value> &vals) {
CHECK(vals.dimensions() ==
layout.tileArrayShape(vty.getShape(), ctx.target_shape));
CHECK_GT(vals.num_elements(), 0);
@@ -2950,11 +2954,11 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
disassemble(ctx, src, v));
SmallVector<int64_t> dst_tiles_shape =
dst.tileArrayShape(vty.getShape(), ctx.target_shape);
if (src.generalizes(dst, vty.getShape(), ctx.target_shape) &&
src.tilesPerVreg(ctx.target_shape) == 1) {
if (src.generalizes(dst, vty.getShape(), ctx.target_shape)) {
return assemble(ctx, vty, dst, std::move(src_tiles)).getResult();
}
if (!src.offsets()[0].has_value() && !src.offsets()[1].has_value()) {
if (!src.offsets()[0].has_value() && !src.offsets()[1].has_value() &&
src.tilesPerVreg(ctx.target_shape) == 1) {
// A fully replicated value is always easy to relayout
// It would be nice to be able to assert this here, but given replicated
// values our rules can introduce equivalent expressions.
@@ -3057,12 +3061,12 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
src_idx[src_idx.size() - 2] *= 4;
src_idx[src_idx.size() - 1] /= 4;
for (int i = 0; i < 4; ++i) {
parts[i] = ctx.builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_i32, src_tiles(src_idx), vreg_part);
if (src_idx[src_idx.size() - 2] <
src_tiles.dim(src_tiles.num_dimensions() - 2) - 1) {
++src_idx[src_idx.size() - 2];
}
parts[i] = ctx.builder.create<tpu::UnpackSubelementsOp>(
v.getLoc(), vreg_i32, src_tiles(src_idx), vreg_part);
}
*tile = ctx.builder.create<tpu::PackSubelementsOp>(
v.getLoc(), src_tiles.begin()->getType(), parts);
@@ -3117,8 +3121,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
.getResult();
});
}
const int src_subelem = src_sublane % packing;
const int dst_subelem = dst_sublane % packing;
const int src_subelem = *src.offsets()[0] % packing;
const int dst_subelem = *dst.offsets()[0] % packing;
if (src_subelem != dst_subelem) {
const int subelem_diff = dst_subelem - src_subelem;
const int shift_bits = bitwidth * std::abs(subelem_diff);
@@ -3139,7 +3143,10 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
shift_tile = ctx.builder.create<arith::ShRUIOp>(
v.getLoc(), bit_tile, shift_vreg);
}
*tile = shift_tile->getResult(0);
*tile = ctx.builder
.create<tpu::BitcastOp>(v.getLoc(), tile->getType(),
shift_tile->getResult(0))
.getResult();
return absl::OkStatus();
});
}

View File

@@ -55,7 +55,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref,
return emitError(UnknownLoc::get(memref.getContext()),
"Non-identity affine layout");
}
if (!memref.isIntOrFloat()) {
if (!memref.getElementType().isIntOrFloat()) {
return emitError(UnknownLoc::get(memref.getContext()),
"Invalid element type for memref");
}
@ -91,10 +91,9 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref,
}
tiles.push_back(xla::Tile({32 / bitwidth, 1}));
}
SmallVector<int64_t> tile_strides;
tile_strides.reserve(memref.getRank());
SmallVector<int64_t> tile_strides(memref.getRank());
int64_t stride = 1;
for (int i = memref.getRank() - 1; i > 0; --i) {
for (int i = memref.getRank() - 1; i >= 0; --i) {
tile_strides[i] = stride;
if (i == memref.getRank() - 1) {
stride *= (memref.getShape()[i] + 127) / 128;