Add support for i1 vmasks with packed tiling and 16-bit comparisons (requires hardware support)

PiperOrigin-RevId: 633677477
This commit is contained in:
Tomás Longeri 2024-05-14 12:54:02 -07:00 committed by jax authors
parent 0501d3d7a0
commit 0ad5167da8
3 changed files with 35 additions and 14 deletions

View File

@ -207,6 +207,15 @@ class RectangularVregBounds : public VRegDataBounds {
// one specified as an attribute.
// implicit_dim: If specified, the value has an implicit dim inserted in
// either minormost or second minormost position.
//
// Note: There is a special case when VectorLayout is used for an mlir::Value
// of i1 type. In this case, we use it to represent a vmask, which has a smaller
// bitwidth than a vreg. For these types, the packing() is accurate but the
// bitwidth() is a lie, and the i1 value is replicated for every bit.
// For example, if the vmask is 8 x 128 x 4 bits and packing() == 2, each 4-bit
// register contains two logical bool values which are represented as either b11
// or b00. Its usage is currently limited to MLIR arith.cmp and arith.select ops
// but we might want to split out a separate class if it gets used more widely.
class VectorLayout {
public:
enum class ImplicitDim {

View File

@ -499,18 +499,34 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
return argument;
}
FailureOr<VectorType> getNativeVregType(
Type elem_ty, const std::array<int64_t, 2> target_shape) {
FAILUREOR_ASSIGN_OR_RETURN(const int8_t bitwidth,
getTypeBitwidth<true>(elem_ty));
FailureOr<VectorType> getNativeVregOrVmaskTypeImpl(
Type elem_ty, const int8_t bitwidth,
const std::array<int64_t, 2> target_shape) {
if (bitwidth == 32) {
return VectorType::get(target_shape, elem_ty);
}
// bitwidth != 32
return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
elem_ty);
}
FailureOr<VectorType> getNativeVregOrVmaskType(
Type elem_ty, const int8_t layout_bitwidth,
const std::array<int64_t, 2> target_shape) {
int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
if (bitwidth == 1) {
bitwidth = layout_bitwidth;
} else {
CHECK_EQ(bitwidth, layout_bitwidth);
}
return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
}
FailureOr<VectorType> getNativeVregType(
Type elem_ty, const std::array<int64_t, 2> target_shape) {
return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(),
target_shape);
}
// Returns empty vector on null attribute
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
if (const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr)) {
@ -610,7 +626,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType out_vreg_ty,
getNativeVregType(out_ty.getElementType(), ctx.target_shape));
getNativeVregOrVmaskType(out_ty.getElementType(), layout_out.bitwidth(),
ctx.target_shape));
NamedAttrList attributes(op.getAttrDictionary());
attributes.erase("in_layout");

View File

@ -179,9 +179,8 @@ class VectorLayoutInferer {
TPU_CHECK_OP(static_cast<bool>(in_ty) == static_cast<bool>(out_ty),
"Input and output are not both vectors?");
if (in_ty) {
TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1 &&
out_ty.getElementTypeBitWidth() == 32,
"Only 1 bit -> 32 bit extensison supported");
TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1,
"Only extending i1 is supported");
}
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
return failure();
@ -193,11 +192,7 @@ class VectorLayoutInferer {
auto rhs_ty = dyn_cast<VectorType>(any_op.getOperand(1).getType());
TPU_CHECK_OP(static_cast<bool>(lhs_ty) == static_cast<bool>(rhs_ty),
"Only one side of cmp is a vector?");
if (lhs_ty) {
TPU_CHECK_OP(lhs_ty.getElementTypeBitWidth() == kNativeBitwidth &&
rhs_ty.getElementTypeBitWidth() == kNativeBitwidth,
"Only 32-bit cmp supported");
}
// TODO(tlongeri): Check that TPU generation supports comparison.
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
return failure();
}