mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[Mosaic] apply_vector_layout C++ rewrite: Use a strided load to load a single row
Corresponds to changes from cl/564831610. PiperOrigin-RevId: 567231873
This commit is contained in:
parent
f304eb89b9
commit
3228453bf6
@ -545,7 +545,12 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
Tiling memref_tiling,
|
||||
getMemRefTiling(load_op.getBase(), ctx.target_shape));
|
||||
if (layout_out.tiling() != memref_tiling) {
|
||||
return op.emitOpError("Not implemented");
|
||||
// Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
|
||||
// TODO(b/295393167): need to support strided load for bitwidth < 32.
|
||||
if (layout_out.bitwidth() != 32 ||
|
||||
layout_out.tiling() != std::array<int64_t, 2>{1, ctx.target_shape[1]}) {
|
||||
return op.emitOpError("Not implemented");
|
||||
}
|
||||
}
|
||||
// TODO(apaszke): Check that loads are from vmem!
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
@ -561,7 +566,13 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
}
|
||||
const SmallVector<int64_t> implicit_shape =
|
||||
layout_out.implicitShape(vty.getShape());
|
||||
const auto ss = implicit_shape[implicit_shape.size() - 2];
|
||||
const int64_t ss = implicit_shape[implicit_shape.size() - 2];
|
||||
int64_t sublane_stride = 1;
|
||||
if (layout_out.bitwidth() == 32 &&
|
||||
layout_out.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
|
||||
ss == 1) {
|
||||
sublane_stride = memref_tiling[0];
|
||||
}
|
||||
const LayoutOffsets offsets = layout_out.offsets();
|
||||
AffineMap load_map;
|
||||
arith::ConstantOp padding;
|
||||
@ -623,7 +634,8 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
CHECK(offsets[0].has_value());
|
||||
tile = ctx.builder.create<tpu::LoadOp>(
|
||||
load_op.getLoc(), target_ty, load_op.getBase(), idxs_vs,
|
||||
bounds->getSublaneMask(mlir_ctx, ctx.target_shape), nullptr);
|
||||
bounds->getSublaneMask(mlir_ctx, ctx.target_shape),
|
||||
ctx.builder.getI32IntegerAttr(sublane_stride));
|
||||
} else {
|
||||
if (load_map) {
|
||||
CHECK(padding);
|
||||
@ -638,9 +650,10 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
|
||||
const SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
|
||||
const auto sublane_mask_attr =
|
||||
DenseBoolArrayAttr::get(mlir_ctx, sublane_mask);
|
||||
tile = ctx.builder.create<tpu::LoadOp>(load_op.getLoc(), target_ty,
|
||||
load_op.getBase(), idxs_vs,
|
||||
sublane_mask_attr, nullptr);
|
||||
tile = ctx.builder.create<tpu::LoadOp>(
|
||||
load_op.getLoc(), target_ty, load_op.getBase(), idxs_vs,
|
||||
sublane_mask_attr,
|
||||
ctx.builder.getI32IntegerAttr(sublane_stride));
|
||||
}
|
||||
}
|
||||
tiles(tile_idxs) = tile->getResult(0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user