[Mosaic] Packed loads and stores with 1D tiling should use (1, 128 * packing)

There are multiple representations for 1D tiling in vector layouts, and we need to choose one consistently.

PiperOrigin-RevId: 638331061
This commit is contained in:
Tomás Longeri 2024-05-29 10:24:20 -07:00 committed by jax authors
parent f0eaba9e29
commit 8f8b976421
2 changed files with 20 additions and 10 deletions

View File

@ -2461,7 +2461,9 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
FAILUREOR_ASSIGN_OR_RETURN(
Tiling memref_tiling,
getMemRefTiling(load_op.getBase(), ctx.target_shape));
if (layout_out.tiling() != memref_tiling) {
if (memref_tiling != layout_out.tiling() &&
!(memref_tiling[0] == 1 && layout_out.tiling()[0] == 1 &&
memref_tiling[1] % layout_out.tiling()[1] == 0)) {
// Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
// TODO(b/295393167): need to support strided load for bitwidth < 32.
if (layout_out.bitwidth() != 32 ||
@ -3659,7 +3661,9 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
FAILUREOR_ASSIGN_OR_RETURN(
const Tiling memref_tiling,
getMemRefTiling(store_op.getBase(), ctx.target_shape));
if (to_store_layout.tiling() != memref_tiling) {
if (memref_tiling != to_store_layout.tiling() &&
!(memref_tiling[0] == 1 && to_store_layout.tiling()[0] == 1 &&
memref_tiling[1] % to_store_layout.tiling()[1] == 0)) {
// Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
// TODO(b/295393167): need to support strided store for bitwidth < 32.
if (to_store_layout.bitwidth() != 32 ||

View File

@ -1019,6 +1019,10 @@ class VectorLayoutInferer {
"memref and vector rank mismatch");
int64_t rank = res_ty.getRank();
int8_t bitwidth = res_ty.getElementTypeBitWidth();
if (kNativeBitwidth % bitwidth != 0) {
return op.emitOpError("Unsupported bitwidth");
}
const int packing = kNativeBitwidth / bitwidth;
auto maybe_tiling =
verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(),
src_ty.getRank(), src_ty.getElementTypeBitWidth());
@ -1050,12 +1054,10 @@ class VectorLayoutInferer {
}
if (rank == 1) {
TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D loads");
const int64_t lane_tiling = packing * target_shape_[1];
auto tile = tiling.front();
TPU_CHECK_OP(tile % target_shape_[1] == 0,
"Unsupported tiling for 1D load");
TPU_CHECK_OP(tile % lane_tiling == 0, "Unsupported tiling for 1D load");
CHECK_EQ(tile_offsets.size(), 1);
// TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
// TODO(apaszke): We could generate replicated loads for short values.
setLayout(op, in_layout,
VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
@ -1372,6 +1374,10 @@ class VectorLayoutInferer {
"memref and vector rank mismatch");
int64_t rank = ref_ty.getRank();
int8_t bitwidth = store_ty.getElementTypeBitWidth();
if (kNativeBitwidth % bitwidth != 0) {
return op.emitOpError("Unsupported bitwidth");
}
const int packing = kNativeBitwidth / bitwidth;
auto maybe_tiling =
verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(),
ref_ty.getRank(), ref_ty.getElementTypeBitWidth());
@ -1402,11 +1408,10 @@ class VectorLayoutInferer {
}
if (rank == 1) {
TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D store");
const int64_t lane_tiling = packing * target_shape_[1];
auto tile = tiling.front();
TPU_CHECK_OP(tile % target_shape_[1] == 0,
TPU_CHECK_OP(tile % lane_tiling == 0,
"Unsupported 1D tiling for 1D store");
// TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
CHECK_EQ(tile_offsets.size(), 1);
store_layout =
VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
@ -1806,6 +1811,7 @@ class VectorLayoutInferer {
std::optional<absl::Span<const int64_t>> verifyMemoryTiling(
Operation *op, ArrayRef<xla::Tile> mem_tiling, int64_t rank,
int8_t bitwidth) {
const int packing = kNativeBitwidth / bitwidth;
if (bitwidth == 32) {
if (mem_tiling.size() != 1) {
op->emitOpError("Only one-level tiling supported for 32-bit loads");
@ -1822,7 +1828,7 @@ class VectorLayoutInferer {
}
auto first = mem_tiling[0].dimensions();
auto second = mem_tiling[1].dimensions();
if (first.size() != 1 || first[0] % target_shape_[1] != 0) {
if (first.size() != 1 || first[0] % (packing * target_shape_[1]) != 0) {
op->emitOpError("Invalid first-level tile in 1D memory op");
return std::nullopt;
}