[Mosaic TPU] The previous change does not actually force the input offsets read by the rules, but simply disables all the checks. Reverting so that we at least regain the checks until we have a proper fix.

Reverts 4a596aee1e8920f5b51d5bd573df976390bbd437

PiperOrigin-RevId: 680925509
This commit is contained in:
Adam Paszke 2024-10-01 02:23:13 -07:00 committed by jax authors
parent 80f963c003
commit f62941d126
2 changed files with 6 additions and 16 deletions

View File

@ -142,12 +142,13 @@ class VectorLayoutInferer {
// support for offsets outside of the first tile. When support is more
// broad, any op without support should check it within their own rule.
if (!isa<vector::BroadcastOp, vector::ExtractStridedSliceOp>(any_op)) {
SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
for (auto &layout : layouts_in) {
const SmallVector<Layout> layouts_in = getLayoutFromOperands(&any_op);
for (const Layout &layout : layouts_in) {
if (layout && layout->offsets()[1].has_value() &&
layout->offsets()[1].value() >= layout->tiling()[1]) {
layout = VectorLayout(layout->bitwidth(), {layout->offsets()[0], 0},
layout->tiling(), layout->implicit_dim());
layout->offsets()[1].value() > layout->tiling()[1]) {
return any_op.emitOpError(
"Not implemented: Inferring from input offsets outside of the "
"first tile");
}
}
}

View File

@ -233,17 +233,6 @@ class OpsTest(PallasBaseTest):
assert (run(cond, lhs, rhs) == lhs).all()
def test_offset_oob(self):
# TODO(b/342235360): Remove this test once we have a better way to handle
# out-of-first-tile offsets.
def body(x_ref, o_ref):
o_ref[...] = x_ref[...][:, 130:230]
out = jax.ShapeDtypeStruct((8, 100), jnp.int16)
x = jnp.arange(8 * 256, dtype=jnp.int16).reshape((8, 256))
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(result, x[:, 130:230])
class OpsInterpretTest(OpsTest):
INTERPRET = True