[Mosaic] Always define tiling as (1, 128) for 1D loaded or stored vectors (not for the memref), instead of sometimes using (1, 128 * n).

They are equivalent - the way values are laid out is the same - but relayouts check specifically for (1, 128). We define (1, 128) to be canonical.

PiperOrigin-RevId: 629748121
This commit is contained in:
Tomás Longeri 2024-05-01 09:36:44 -07:00 committed by jax authors
parent 26049b1059
commit 9bf1148e74

View File

@ -1025,10 +1025,12 @@ class VectorLayoutInferer {
TPU_CHECK_OP(tile % target_shape_[1] == 0,
"Unsupported tiling for 1D load");
CHECK_EQ(tile_offsets.size(), 1);
// TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
// TODO(apaszke): We could generate replicated loads for short values.
setLayout(op, in_layout,
VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile},
ImplicitDim::kSecondMinor));
VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
{1, lane_tiling}, ImplicitDim::kSecondMinor));
} else { // rank >= 2
TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ loads");
CHECK_EQ(tile_offsets.size(), 2);
@ -1366,9 +1368,12 @@ class VectorLayoutInferer {
auto tile = tiling.front();
TPU_CHECK_OP(tile % target_shape_[1] == 0,
"Unsupported 1D tiling for 1D store");
// TODO(tlongeri): Also pick a unique (canonical) tiling for packed types
const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile;
CHECK_EQ(tile_offsets.size(), 1);
store_layout = VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile},
ImplicitDim::kSecondMinor);
store_layout =
VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling},
{1, lane_tiling}, ImplicitDim::kSecondMinor);
} else { // rank >= 2 // NOLINT(readability-else-after-return)
TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ store");
CHECK_EQ(tile_offsets.size(), 2);