[Mosaic] apply_vector_layout C++ rewrite: Use strided load to load single row

Corresponds to changes from cl/564831610

PiperOrigin-RevId: 567231873
This commit is contained in:
Tomás Longeri 2023-09-21 01:59:33 -07:00 committed by jax authors
parent f304eb89b9
commit 3228453bf6

View File

@ -545,7 +545,12 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
Tiling memref_tiling,
getMemRefTiling(load_op.getBase(), ctx.target_shape));
if (layout_out.tiling() != memref_tiling) {
return op.emitOpError("Not implemented");
// A tiling mismatch is supported only when the layout tiling is (1, ctx.target_shape[1]).
// TODO(b/295393167): need to support strided load for bitwidth < 32.
if (layout_out.bitwidth() != 32 ||
layout_out.tiling() != std::array<int64_t, 2>{1, ctx.target_shape[1]}) {
return op.emitOpError("Not implemented");
}
}
// TODO(apaszke): Check that loads are from vmem!
FAILUREOR_ASSIGN_OR_RETURN(
@ -561,7 +566,13 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
}
const SmallVector<int64_t> implicit_shape =
layout_out.implicitShape(vty.getShape());
const auto ss = implicit_shape[implicit_shape.size() - 2];
const int64_t ss = implicit_shape[implicit_shape.size() - 2];
int64_t sublane_stride = 1;
if (layout_out.bitwidth() == 32 &&
layout_out.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
ss == 1) {
sublane_stride = memref_tiling[0];
}
const LayoutOffsets offsets = layout_out.offsets();
AffineMap load_map;
arith::ConstantOp padding;
@ -623,7 +634,8 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
CHECK(offsets[0].has_value());
tile = ctx.builder.create<tpu::LoadOp>(
load_op.getLoc(), target_ty, load_op.getBase(), idxs_vs,
bounds->getSublaneMask(mlir_ctx, ctx.target_shape), nullptr);
bounds->getSublaneMask(mlir_ctx, ctx.target_shape),
ctx.builder.getI32IntegerAttr(sublane_stride));
} else {
if (load_map) {
CHECK(padding);
@ -638,9 +650,10 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
const SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
const auto sublane_mask_attr =
DenseBoolArrayAttr::get(mlir_ctx, sublane_mask);
tile = ctx.builder.create<tpu::LoadOp>(load_op.getLoc(), target_ty,
load_op.getBase(), idxs_vs,
sublane_mask_attr, nullptr);
tile = ctx.builder.create<tpu::LoadOp>(
load_op.getLoc(), target_ty, load_op.getBase(), idxs_vs,
sublane_mask_attr,
ctx.builder.getI32IntegerAttr(sublane_stride));
}
}
tiles(tile_idxs) = tile->getResult(0);