mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[Mosaic TPU] Change getLayout to force offset to 0 when inferring input has offset out of the first tile.
PiperOrigin-RevId: 684145987
This commit is contained in:
parent
64a757450c
commit
f52b016de1
@ -144,11 +144,9 @@ class VectorLayoutInferer {
|
||||
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(any_op)) {
|
||||
const SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
|
||||
for (const Layout &layout : layouts_in) {
|
||||
if (layout && layout->offsets()[1].has_value() &&
|
||||
layout->offsets()[1].value() > layout->tiling()[1]) {
|
||||
return any_op.emitOpError(
|
||||
"Not implemented: Inferring from input offsets outside of the "
|
||||
"first tile");
|
||||
if (layout &&
|
||||
layout->offsets()[1].value_or(0) >= layout->tiling()[1]) {
|
||||
force_first_tile_offsets_ = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -349,6 +347,7 @@ class VectorLayoutInferer {
|
||||
}
|
||||
CHECK(any_op.getNumResults() == 0 || any_op.hasAttr("out_layout"));
|
||||
CHECK(any_op.getNumOperands() == 0 || any_op.hasAttr("in_layout"));
|
||||
force_first_tile_offsets_ = false;
|
||||
}
|
||||
return match_terminator(block.getTerminator());
|
||||
}
|
||||
@ -1940,7 +1939,14 @@ class VectorLayoutInferer {
|
||||
auto result_index = op_result.getResultNumber();
|
||||
auto out_attrs = op->getAttrOfType<ArrayAttr>("out_layout").getValue();
|
||||
CHECK(out_attrs.size() > result_index);
|
||||
return cast<VectorLayoutAttr>(out_attrs[result_index]).getLayout();
|
||||
auto layout = cast<VectorLayoutAttr>(out_attrs[result_index]).getLayout();
|
||||
if (force_first_tile_offsets_ &&
|
||||
layout->offsets()[1].value_or(0) >= layout->tiling()[1]) {
|
||||
// Force the out-of-first-tile offset to be zero.
|
||||
layout = VectorLayout(layout->bitwidth(), {layout->offsets()[0], 0},
|
||||
layout->tiling(), layout->implicit_dim());
|
||||
}
|
||||
return layout;
|
||||
}
|
||||
|
||||
SmallVector<Layout, 4> getLayoutFromOperands(Operation *op) {
|
||||
@ -2024,6 +2030,10 @@ class VectorLayoutInferer {
|
||||
std::array<int64_t, 2> target_shape_;
|
||||
std::array<int64_t, 2> default_tiling_;
|
||||
|
||||
// TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully
|
||||
// remove the restriction that offsets must fall within the first tile.
|
||||
bool force_first_tile_offsets_ = false;
|
||||
|
||||
// Address alignment requirement, counted in 32-bit increments.
|
||||
static constexpr int64_t kVmemAlignment32 = 128;
|
||||
// TODO(apaszke): This is not really native on newer generations of TPUs.
|
||||
|
Loading…
x
Reference in New Issue
Block a user