From 9bf1148e74a4190bf5d466488f11d7d3cde4887c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 1 May 2024 09:36:44 -0700 Subject: [PATCH] [Mosaic] Always define tiling as (1, 128) for 1D loaded or stored vectors (not for the memref), instead of sometimes using (1, 128 * n). They are equivalent - the way values are laid out is the same - but relayouts check specifically for (1, 128). We define (1, 128) to be canonical. PiperOrigin-RevId: 629748121 --- .../dialect/tpu/transforms/infer_vector_layout.cc | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 67ebca2eb..5b93ee36a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -1025,10 +1025,12 @@ class VectorLayoutInferer { TPU_CHECK_OP(tile % target_shape_[1] == 0, "Unsupported tiling for 1D load"); CHECK_EQ(tile_offsets.size(), 1); + // TODO(tlongeri): Also pick a unique (canonical) tiling for packed types + const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile; // TODO(apaszke): We could generate replicated loads for short values. setLayout(op, in_layout, - VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile}, - ImplicitDim::kSecondMinor)); + VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling}, + {1, lane_tiling}, ImplicitDim::kSecondMinor)); } else { // rank >= 2 TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ loads"); CHECK_EQ(tile_offsets.size(), 2); @@ -1366,9 +1368,12 @@ class VectorLayoutInferer { auto tile = tiling.front(); TPU_CHECK_OP(tile % target_shape_[1] == 0, "Unsupported 1D tiling for 1D store"); + // TODO(tlongeri): Also pick a unique (canonical) tiling for packed types + const int64_t lane_tiling = bitwidth == 32 ? target_shape_[1] : tile; CHECK_EQ(tile_offsets.size(), 1); - store_layout = VectorLayout(bitwidth, {0, tile_offsets[0]}, {1, tile}, - ImplicitDim::kSecondMinor); + store_layout = + VectorLayout(bitwidth, {0, tile_offsets[0] % lane_tiling}, + {1, lane_tiling}, ImplicitDim::kSecondMinor); } else { // rank >= 2 // NOLINT(readability-else-after-return) TPU_CHECK_OP(tiling.size() == 2, "Expected 2D tiling in 2D+ store"); CHECK_EQ(tile_offsets.size(), 2);