From 0ad5167da88a4d7cfffc67603cf5b5174f9c02bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Tue, 14 May 2024 12:54:02 -0700 Subject: [PATCH] Add support for i1 vmasks with packed tiling and 16-bit comparisons (requires hardware support) PiperOrigin-RevId: 633677477 --- jaxlib/mosaic/dialect/tpu/layout.h | 9 ++++++ .../tpu/transforms/apply_vector_layout.cc | 29 +++++++++++++++---- .../tpu/transforms/infer_vector_layout.cc | 11 ++----- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/layout.h b/jaxlib/mosaic/dialect/tpu/layout.h index 516c06360..88f262269 100644 --- a/jaxlib/mosaic/dialect/tpu/layout.h +++ b/jaxlib/mosaic/dialect/tpu/layout.h @@ -207,6 +207,15 @@ class RectangularVregBounds : public VRegDataBounds { // one specified as an attribute. // implicit_dim: If specified, the value has an implicit dim inserted in // either minormost or second minormost position. +// +// Note: There is a special case when VectorLayout is used for an mlir::Value +// of i1 type. In this case, we use it to represent a vmask, which has a smaller +// bitwidth than a vreg. For these types, the packing() is accurate but the +// bitwidth() is a lie, and the i1 value is replicated for every bit. +// For example, if the vmask is 8 x 128 x 4 bits and packing() == 2, each 4-bit +// register contains two logical bool values which are represented as either b11 +// or b00. Its usage is currently limited to MLIR arith.cmp and arith.select ops +// but we might want to split out a separate class if it gets used more widely. class VectorLayout { public: enum class ImplicitDim { diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 072c2f557..fed7a9b15 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -499,18 +499,34 @@ FailureOr appendConstant(RewriteContext &ctx, return argument; } -FailureOr getNativeVregType( - Type elem_ty, const std::array target_shape) { - FAILUREOR_ASSIGN_OR_RETURN(const int8_t bitwidth, - getTypeBitwidth(elem_ty)); +FailureOr getNativeVregOrVmaskTypeImpl( + Type elem_ty, const int8_t bitwidth, + const std::array target_shape) { if (bitwidth == 32) { return VectorType::get(target_shape, elem_ty); } - // bitwidth != 32 return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth}, elem_ty); } +FailureOr getNativeVregOrVmaskType( + Type elem_ty, const int8_t layout_bitwidth, + const std::array target_shape) { + int8_t bitwidth = elem_ty.getIntOrFloatBitWidth(); + if (bitwidth == 1) { + bitwidth = layout_bitwidth; + } else { + CHECK_EQ(bitwidth, layout_bitwidth); + } + return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape); +} + +FailureOr getNativeVregType( + Type elem_ty, const std::array target_shape) { + return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(), + target_shape); +} + // Returns empty vector on null attribute FailureOr> getLayoutArrayFromAttr(const Attribute attr) { if (const auto array_attr = dyn_cast_if_present(attr)) { @@ -610,7 +626,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, FAILUREOR_ASSIGN_OR_RETURN( const VectorType out_vreg_ty, - getNativeVregType(out_ty.getElementType(), ctx.target_shape)); + getNativeVregOrVmaskType(out_ty.getElementType(), layout_out.bitwidth(), + ctx.target_shape)); NamedAttrList attributes(op.getAttrDictionary()); attributes.erase("in_layout"); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 7c70c9526..daa260711 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -179,9 +179,8 @@ class VectorLayoutInferer { TPU_CHECK_OP(static_cast(in_ty) == static_cast(out_ty), "Input and output are not both vectors?"); if (in_ty) { - TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1 && - out_ty.getElementTypeBitWidth() == 32, - "Only 1 bit -> 32 bit extensison supported"); + TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1, + "Only extending i1 is supported"); } if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { return failure(); @@ -193,11 +192,7 @@ class VectorLayoutInferer { auto rhs_ty = dyn_cast(any_op.getOperand(1).getType()); TPU_CHECK_OP(static_cast(lhs_ty) == static_cast(rhs_ty), "Only one side of cmp is a vector?"); - if (lhs_ty) { - TPU_CHECK_OP(lhs_ty.getElementTypeBitWidth() == kNativeBitwidth && - rhs_ty.getElementTypeBitWidth() == kNativeBitwidth, - "Only 32-bit cmp supported"); - } + // TODO(tlongeri): Check that TPU generation supports comparison. if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { return failure(); }