mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[Mosaic TPU] Enable non-sublane-aligned bf16 2D load/stores for earlier TPU gens
It is still not efficiently implemented, this is mostly to clean up some logic. We may be able to fuse the creation of masks for different tiles into the creation of a single one. But this is also a problem for the later gens. This also cleans up an unreachable return statement. PiperOrigin-RevId: 714847066
This commit is contained in:
parent
0930289997
commit
7852045582
@ -259,17 +259,13 @@ class TiledRectangularVregBounds : public VRegDataBounds {
|
||||
// TODO(b/300082350): Generalize this
|
||||
return emitError(loc, "Not implemented");
|
||||
}
|
||||
// For older TPUs, we virtualize masking, but only for simple cases.
|
||||
// For older TPUs, we virtualize masking
|
||||
if (generation < 4) {
|
||||
if (num_tiles_ > 1) {
|
||||
return emitError(loc, "Not implemented");
|
||||
}
|
||||
return VectorType::get(target_shape, i1);
|
||||
} else {
|
||||
return VectorType::get(
|
||||
{target_shape[0], target_shape[1], layout_.packing()}, i1);
|
||||
}
|
||||
return VectorType::get({target_shape[0], target_shape[1], 2}, i1);
|
||||
}
|
||||
return VectorType::get(target_shape, i1);
|
||||
}());
|
||||
@ -327,11 +323,6 @@ class TiledRectangularVregBounds : public VRegDataBounds {
|
||||
loc, mask_vreg_ty, start_row, end_row);
|
||||
tile_mask = builder.create<arith::AndIOp>(loc, tile_mask, submask);
|
||||
} else { // generation < 4
|
||||
if (num_tiles_ > 1) {
|
||||
return emitError(loc,
|
||||
"Not implemented: TPU generations before 4 cannot "
|
||||
"handle all bf16 masking");
|
||||
}
|
||||
const auto getMaskCst = [&](const uint64_t v) {
|
||||
const auto int_mask_ty =
|
||||
VectorType::get(target_shape, builder.getI32Type());
|
||||
@ -341,7 +332,7 @@ class TiledRectangularVregBounds : public VRegDataBounds {
|
||||
int_mask_ty, builder.getIntegerAttr(builder.getI32Type(),
|
||||
APInt(32, v))));
|
||||
};
|
||||
Value tile_bitmask = builder.create<arith::SelectOp>(
|
||||
tile_mask = builder.create<arith::SelectOp>(
|
||||
loc, tile_mask, getMaskCst(0xFFFFFFFF), getMaskCst(0));
|
||||
if (start_row % 2 != 0) {
|
||||
auto row_mask = builder.create<tpu::CreateMaskOp>(
|
||||
@ -351,8 +342,8 @@ class TiledRectangularVregBounds : public VRegDataBounds {
|
||||
boundIdxConst(target_shape[1])});
|
||||
auto row_bitmask = builder.create<arith::SelectOp>(
|
||||
loc, row_mask, getMaskCst(0xFFFF0000), getMaskCst(0xFFFFFFFF));
|
||||
tile_bitmask =
|
||||
builder.create<arith::AndIOp>(loc, tile_bitmask, row_bitmask);
|
||||
tile_mask =
|
||||
builder.create<arith::AndIOp>(loc, tile_mask, row_bitmask);
|
||||
}
|
||||
if (end_row % 2 != 0) {
|
||||
auto row_mask = builder.create<tpu::CreateMaskOp>(
|
||||
@ -362,10 +353,9 @@ class TiledRectangularVregBounds : public VRegDataBounds {
|
||||
boundIdxConst(target_shape[1])});
|
||||
auto row_bitmask = builder.create<arith::SelectOp>(
|
||||
loc, row_mask, getMaskCst(0xFFFF), getMaskCst(0xFFFFFFFF));
|
||||
tile_bitmask =
|
||||
builder.create<arith::AndIOp>(loc, tile_bitmask, row_bitmask);
|
||||
tile_mask =
|
||||
builder.create<arith::AndIOp>(loc, tile_mask, row_bitmask);
|
||||
}
|
||||
return cast<TypedValue<VectorType>>(tile_bitmask);
|
||||
}
|
||||
}
|
||||
mask = mask == nullptr
|
||||
|
Loading…
x
Reference in New Issue
Block a user