[Mosaic TPU] Change getLayout to force offset to 0 when inferring input has offset out of the first tile.

PiperOrigin-RevId: 684145987
This commit is contained in:
Jevin Jiang 2024-10-09 13:11:14 -07:00 committed by jax authors
parent 64a757450c
commit f52b016de1

View File

@ -144,11 +144,9 @@ class VectorLayoutInferer {
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(any_op)) {
const SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
for (const Layout &layout : layouts_in) {
if (layout && layout->offsets()[1].has_value() &&
layout->offsets()[1].value() > layout->tiling()[1]) {
return any_op.emitOpError(
"Not implemented: Inferring from input offsets outside of the "
"first tile");
if (layout &&
layout->offsets()[1].value_or(0) >= layout->tiling()[1]) {
force_first_tile_offsets_ = true;
}
}
}
@ -349,6 +347,7 @@ class VectorLayoutInferer {
}
CHECK(any_op.getNumResults() == 0 || any_op.hasAttr("out_layout"));
CHECK(any_op.getNumOperands() == 0 || any_op.hasAttr("in_layout"));
force_first_tile_offsets_ = false;
}
return match_terminator(block.getTerminator());
}
@ -1940,7 +1939,14 @@ class VectorLayoutInferer {
auto result_index = op_result.getResultNumber();
auto out_attrs = op->getAttrOfType<ArrayAttr>("out_layout").getValue();
CHECK(out_attrs.size() > result_index);
return cast<VectorLayoutAttr>(out_attrs[result_index]).getLayout();
auto layout = cast<VectorLayoutAttr>(out_attrs[result_index]).getLayout();
if (force_first_tile_offsets_ &&
layout->offsets()[1].value_or(0) >= layout->tiling()[1]) {
// Force the out-of-first-tile offset to be zero.
layout = VectorLayout(layout->bitwidth(), {layout->offsets()[0], 0},
layout->tiling(), layout->implicit_dim());
}
return layout;
}
SmallVector<Layout, 4> getLayoutFromOperands(Operation *op) {
@ -2024,6 +2030,10 @@ class VectorLayoutInferer {
std::array<int64_t, 2> target_shape_;
std::array<int64_t, 2> default_tiling_;
// TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully
// remove the restriction that offsets must fall within the first tile.
bool force_first_tile_offsets_ = false;
// Address alignment requirement, counted in 32-bit increments.
static constexpr int64_t kVmemAlignment32 = 128;
// TODO(apaszke): This is not really native on newer generations of TPUs.