mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic] apply_vector_layout C++ rewrite: various bug fixes
PiperOrigin-RevId: 571075082
This commit is contained in:
parent
295cecd505
commit
ab4a8e3417
@ -276,9 +276,13 @@ class TiledRectangularVregBounds : public VRegDataBounds {
|
||||
return VectorType::get(target_shape, i1);
|
||||
}());
|
||||
if (isComplete(target_shape)) {
|
||||
builder.create<arith::ConstantOp>(
|
||||
loc, mask_vreg_ty,
|
||||
DenseElementsAttr::get(mask_vreg_ty, builder.getBoolAttr(true)));
|
||||
return cast<TypedValue<VectorType>>(
|
||||
builder
|
||||
.create<arith::ConstantOp>(
|
||||
loc, mask_vreg_ty,
|
||||
DenseElementsAttr::get(mask_vreg_ty,
|
||||
builder.getBoolAttr(true)))
|
||||
.getResult());
|
||||
}
|
||||
Value mask = nullptr;
|
||||
CHECK_GE(num_tiles_, 0);
|
||||
@ -488,7 +492,7 @@ llvm::SmallVector<int64_t> VectorLayout::tileArrayShape(
|
||||
tiles_shape.pop_back();
|
||||
break;
|
||||
case ImplicitDim::kSecondMinor:
|
||||
tiles_shape.erase(tiles_shape.end() - 1);
|
||||
tiles_shape.erase(tiles_shape.end() - 2);
|
||||
break;
|
||||
}
|
||||
return tiles_shape;
|
||||
|
@ -60,9 +60,9 @@ struct VRegDataBounds {
|
||||
std::array<int64_t, 2> target_shape) const = 0;
|
||||
|
||||
bool isComplete(const std::array<int64_t, 2> target_shape) const {
|
||||
return maskVariesAlong(Direction::kSublanes, target_shape) ||
|
||||
maskVariesAlong(Direction::kLanes, target_shape) ||
|
||||
maskVariesAlong(Direction::kSubelements, target_shape);
|
||||
return !maskVariesAlong(Direction::kSublanes, target_shape) &&
|
||||
!maskVariesAlong(Direction::kLanes, target_shape) &&
|
||||
!maskVariesAlong(Direction::kSubelements, target_shape);
|
||||
}
|
||||
|
||||
// Constructs a vector mask value that is true iff the entry contains useful
|
||||
|
@ -83,7 +83,8 @@ struct RewriteContext {
|
||||
|
||||
LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
|
||||
RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
|
||||
const VectorLayout &layout, xla::Array<Value> vals);
|
||||
const VectorLayout &layout,
|
||||
const xla::Array<Value> &vals);
|
||||
FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
|
||||
const VectorLayout &layout, Value val);
|
||||
namespace {
|
||||
@ -319,6 +320,9 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
|
||||
vector_constants.push_back(value);
|
||||
ctx.func->setAttr("vector_constants",
|
||||
ArrayAttr::get(ctx.func.getContext(), vector_constants));
|
||||
} else {
|
||||
ctx.func->setAttr("vector_constants",
|
||||
ArrayAttr::get(ctx.func.getContext(), value));
|
||||
}
|
||||
// Adjust window params for the extra operand.
|
||||
if (auto window_params =
|
||||
@ -337,7 +341,7 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
|
||||
StringAttr::get(ctx.func.getContext(), "transform_indices"),
|
||||
AffineMapAttr::get(transform_indices)));
|
||||
SmallVector<Attribute> window_params_values(window_params.getValue());
|
||||
window_params_values.push_back(new_param);
|
||||
window_params_values.insert(window_params_values.end() - 1, new_param);
|
||||
ctx.func->setAttr("window_params", ArrayAttr::get(ctx.func.getContext(),
|
||||
window_params_values));
|
||||
}
|
||||
@ -650,7 +654,7 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
|
||||
CHECK(layouts_in.front().has_value());
|
||||
CHECK(layouts_out.front().has_value());
|
||||
auto extf_op = cast<arith::ExtFOp>(op);
|
||||
if (layouts_in.front()->bitwidth() != 32 ||
|
||||
if (layouts_in.front()->bitwidth() != 16 ||
|
||||
layouts_out.front()->bitwidth() != 32) {
|
||||
return op.emitOpError("Only 16-bit to 32-bit conversion supported");
|
||||
}
|
||||
@ -2526,7 +2530,7 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
rules_elementwise_op_entry<math::ExpOp, 1>(),
|
||||
rules_elementwise_op_entry<math::CosOp, 1>(),
|
||||
rules_elementwise_op_entry<math::SinOp, 1>(),
|
||||
rules_elementwise_op_entry<math::PowFOp, 1>(),
|
||||
rules_elementwise_op_entry<math::PowFOp, 2>(),
|
||||
rules_elementwise_op_entry<math::RsqrtOp, 1>(),
|
||||
rules_elementwise_op_entry<math::TanhOp, 1>(),
|
||||
{func::ReturnOp::getOperationName(), func_return_rule},
|
||||
@ -2554,7 +2558,7 @@ const llvm::StringMap<rule_type> &rules() {
|
||||
|
||||
RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
|
||||
const VectorLayout &layout,
|
||||
const xla::Array<Value> vals) {
|
||||
const xla::Array<Value> &vals) {
|
||||
CHECK(vals.dimensions() ==
|
||||
layout.tileArrayShape(vty.getShape(), ctx.target_shape));
|
||||
CHECK_GT(vals.num_elements(), 0);
|
||||
@ -2950,11 +2954,11 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
|
||||
disassemble(ctx, src, v));
|
||||
SmallVector<int64_t> dst_tiles_shape =
|
||||
dst.tileArrayShape(vty.getShape(), ctx.target_shape);
|
||||
if (src.generalizes(dst, vty.getShape(), ctx.target_shape) &&
|
||||
src.tilesPerVreg(ctx.target_shape) == 1) {
|
||||
if (src.generalizes(dst, vty.getShape(), ctx.target_shape)) {
|
||||
return assemble(ctx, vty, dst, std::move(src_tiles)).getResult();
|
||||
}
|
||||
if (!src.offsets()[0].has_value() && !src.offsets()[1].has_value()) {
|
||||
if (!src.offsets()[0].has_value() && !src.offsets()[1].has_value() &&
|
||||
src.tilesPerVreg(ctx.target_shape) == 1) {
|
||||
// A fully replicated value is always easy to relayout
|
||||
// It would be nice to be able to assert this here, but given replicated
|
||||
// values our rules can introduce equivalent expressions.
|
||||
@ -3057,12 +3061,12 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
|
||||
src_idx[src_idx.size() - 2] *= 4;
|
||||
src_idx[src_idx.size() - 1] /= 4;
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
parts[i] = ctx.builder.create<tpu::UnpackSubelementsOp>(
|
||||
v.getLoc(), vreg_i32, src_tiles(src_idx), vreg_part);
|
||||
if (src_idx[src_idx.size() - 2] <
|
||||
src_tiles.dim(src_tiles.num_dimensions() - 2) - 1) {
|
||||
++src_idx[src_idx.size() - 2];
|
||||
}
|
||||
parts[i] = ctx.builder.create<tpu::UnpackSubelementsOp>(
|
||||
v.getLoc(), vreg_i32, src_tiles(src_idx), vreg_part);
|
||||
}
|
||||
*tile = ctx.builder.create<tpu::PackSubelementsOp>(
|
||||
v.getLoc(), src_tiles.begin()->getType(), parts);
|
||||
@ -3117,8 +3121,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
|
||||
.getResult();
|
||||
});
|
||||
}
|
||||
const int src_subelem = src_sublane % packing;
|
||||
const int dst_subelem = dst_sublane % packing;
|
||||
const int src_subelem = *src.offsets()[0] % packing;
|
||||
const int dst_subelem = *dst.offsets()[0] % packing;
|
||||
if (src_subelem != dst_subelem) {
|
||||
const int subelem_diff = dst_subelem - src_subelem;
|
||||
const int shift_bits = bitwidth * std::abs(subelem_diff);
|
||||
@ -3139,7 +3143,10 @@ FailureOr<Value> relayout(RewriteContext &ctx, Value v, VectorLayout src,
|
||||
shift_tile = ctx.builder.create<arith::ShRUIOp>(
|
||||
v.getLoc(), bit_tile, shift_vreg);
|
||||
}
|
||||
*tile = shift_tile->getResult(0);
|
||||
*tile = ctx.builder
|
||||
.create<tpu::BitcastOp>(v.getLoc(), tile->getType(),
|
||||
shift_tile->getResult(0))
|
||||
.getResult();
|
||||
return absl::OkStatus();
|
||||
});
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref,
|
||||
return emitError(UnknownLoc::get(memref.getContext()),
|
||||
"Non-identity affine layout");
|
||||
}
|
||||
if (!memref.isIntOrFloat()) {
|
||||
if (!memref.getElementType().isIntOrFloat()) {
|
||||
return emitError(UnknownLoc::get(memref.getContext()),
|
||||
"Invalid element type for memref");
|
||||
}
|
||||
@ -91,10 +91,9 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref,
|
||||
}
|
||||
tiles.push_back(xla::Tile({32 / bitwidth, 1}));
|
||||
}
|
||||
SmallVector<int64_t> tile_strides;
|
||||
tile_strides.reserve(memref.getRank());
|
||||
SmallVector<int64_t> tile_strides(memref.getRank());
|
||||
int64_t stride = 1;
|
||||
for (int i = memref.getRank() - 1; i > 0; --i) {
|
||||
for (int i = memref.getRank() - 1; i >= 0; --i) {
|
||||
tile_strides[i] = stride;
|
||||
if (i == memref.getRank() - 1) {
|
||||
stride *= (memref.getShape()[i] + 127) / 128;
|
||||
|
Loading…
x
Reference in New Issue
Block a user