[Mosaic TPU] Enable non-sublane-aligned bf16 2D load/stores for earlier TPU gens

It is still not efficiently implemented; this change is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one, but that is also a problem for the later TPU generations.

This also cleans up an unreachable return statement.

PiperOrigin-RevId: 714847066
This commit is contained in:
Tomás Longeri 2025-01-12 23:58:10 -08:00 committed by jax authors
parent 0930289997
commit 7852045582

View File

@ -259,17 +259,13 @@ class TiledRectangularVregBounds : public VRegDataBounds {
// TODO(b/300082350): Generalize this
return emitError(loc, "Not implemented");
}
// For older TPUs, we virtualize masking, but only for simple cases.
// For older TPUs, we virtualize masking
if (generation < 4) {
if (num_tiles_ > 1) {
return emitError(loc, "Not implemented");
}
return VectorType::get(target_shape, i1);
} else {
return VectorType::get(
{target_shape[0], target_shape[1], layout_.packing()}, i1);
}
return VectorType::get({target_shape[0], target_shape[1], 2}, i1);
}
return VectorType::get(target_shape, i1);
}());
@ -327,11 +323,6 @@ class TiledRectangularVregBounds : public VRegDataBounds {
loc, mask_vreg_ty, start_row, end_row);
tile_mask = builder.create<arith::AndIOp>(loc, tile_mask, submask);
} else { // generation < 4
if (num_tiles_ > 1) {
return emitError(loc,
"Not implemented: TPU generations before 4 cannot "
"handle all bf16 masking");
}
const auto getMaskCst = [&](const uint64_t v) {
const auto int_mask_ty =
VectorType::get(target_shape, builder.getI32Type());
@ -341,7 +332,7 @@ class TiledRectangularVregBounds : public VRegDataBounds {
int_mask_ty, builder.getIntegerAttr(builder.getI32Type(),
APInt(32, v))));
};
Value tile_bitmask = builder.create<arith::SelectOp>(
tile_mask = builder.create<arith::SelectOp>(
loc, tile_mask, getMaskCst(0xFFFFFFFF), getMaskCst(0));
if (start_row % 2 != 0) {
auto row_mask = builder.create<tpu::CreateMaskOp>(
@ -351,8 +342,8 @@ class TiledRectangularVregBounds : public VRegDataBounds {
boundIdxConst(target_shape[1])});
auto row_bitmask = builder.create<arith::SelectOp>(
loc, row_mask, getMaskCst(0xFFFF0000), getMaskCst(0xFFFFFFFF));
tile_bitmask =
builder.create<arith::AndIOp>(loc, tile_bitmask, row_bitmask);
tile_mask =
builder.create<arith::AndIOp>(loc, tile_mask, row_bitmask);
}
if (end_row % 2 != 0) {
auto row_mask = builder.create<tpu::CreateMaskOp>(
@ -362,10 +353,9 @@ class TiledRectangularVregBounds : public VRegDataBounds {
boundIdxConst(target_shape[1])});
auto row_bitmask = builder.create<arith::SelectOp>(
loc, row_mask, getMaskCst(0xFFFF), getMaskCst(0xFFFFFFFF));
tile_bitmask =
builder.create<arith::AndIOp>(loc, tile_bitmask, row_bitmask);
tile_mask =
builder.create<arith::AndIOp>(loc, tile_mask, row_bitmask);
}
return cast<TypedValue<VectorType>>(tile_bitmask);
}
}
mask = mask == nullptr