mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add support for i1 vmasks with packed tiling and 16-bit comparisons (requires hardware support)
PiperOrigin-RevId: 633677477
This commit is contained in:
parent
0501d3d7a0
commit
0ad5167da8
@ -207,6 +207,15 @@ class RectangularVregBounds : public VRegDataBounds {
|
||||
// one specified as an attribute.
|
||||
// implicit_dim: If specified, the value has an implicit dim inserted in
|
||||
// either minormost or second minormost position.
|
||||
//
|
||||
// Note: There is a special case when VectorLayout is used for an mlir::Value
|
||||
// of i1 type. In this case, we use it to represent a vmask, which has a smaller
|
||||
// bitwidth than a vreg. For these types, the packing() is accurate but the
|
||||
// bitwidth() is a lie, and the i1 value is replicated for every bit.
|
||||
// For example, if the vmask is 8 x 128 x 4 bits and packing() == 2, each 4-bit
|
||||
// register contains two logical bool values which are represented as either b11
|
||||
// or b00. Its usage is currently limited to MLIR arith.cmp and arith.select ops
|
||||
// but we might want to split out a separate class if it gets used more widely.
|
||||
class VectorLayout {
|
||||
public:
|
||||
enum class ImplicitDim {
|
||||
|
@ -499,18 +499,34 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
|
||||
return argument;
|
||||
}
|
||||
|
||||
FailureOr<VectorType> getNativeVregType(
|
||||
Type elem_ty, const std::array<int64_t, 2> target_shape) {
|
||||
FAILUREOR_ASSIGN_OR_RETURN(const int8_t bitwidth,
|
||||
getTypeBitwidth<true>(elem_ty));
|
||||
// Builds the vector type of a single native register for elements of type
// `elem_ty` that occupy `bitwidth` bits each.
//
// elem_ty:      Element type of the resulting vector type.
// bitwidth:     Number of bits each element is treated as occupying. This may
//               differ from the bitwidth of `elem_ty` itself (callers pass the
//               layout bitwidth for i1 element types).
// target_shape: The two leading dimensions of the resulting type.
//
// Returns a 2D type of shape `target_shape` when `bitwidth` is 32, otherwise a
// 3D type with a trailing dimension of size 32 / bitwidth. Returns failure()
// if `bitwidth` is not a positive divisor of 32, since the division below
// would otherwise divide by zero or produce a zero-sized/truncated dimension.
FailureOr<VectorType> getNativeVregOrVmaskTypeImpl(
    Type elem_ty, const int8_t bitwidth,
    const std::array<int64_t, 2> target_shape) {
  // The bitwidth <= 0 test must come first so that the modulo is never
  // evaluated with a zero divisor.
  if (bitwidth <= 0 || 32 % bitwidth != 0) {
    return failure();
  }
  if (bitwidth == 32) {
    return VectorType::get(target_shape, elem_ty);
  }
  // bitwidth != 32
  return VectorType::get({target_shape[0], target_shape[1], 32 / bitwidth},
                         elem_ty);
}
|
||||
|
||||
// Returns the native register type for `elem_ty` under a layout whose
// bitwidth is `layout_bitwidth`.
//
// A 1-bit element type takes its bitwidth from the layout. Any other element
// type must have exactly the layout's bitwidth, which is enforced with
// CHECK_EQ.
FailureOr<VectorType> getNativeVregOrVmaskType(
    Type elem_ty, const int8_t layout_bitwidth,
    const std::array<int64_t, 2> target_shape) {
  const int8_t bitwidth = elem_ty.getIntOrFloatBitWidth();
  if (bitwidth != 1) {
    // Not a 1-bit type: the element and layout bitwidths have to agree.
    CHECK_EQ(bitwidth, layout_bitwidth);
    return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape);
  }
  // 1-bit element type: defer to the bitwidth recorded in the layout.
  return getNativeVregOrVmaskTypeImpl(elem_ty, layout_bitwidth, target_shape);
}
|
||||
|
||||
// Returns the native register type for `elem_ty`, using the element type's
// own integer/float bitwidth to determine the shape.
FailureOr<VectorType> getNativeVregType(
    Type elem_ty, const std::array<int64_t, 2> target_shape) {
  const int8_t elem_bitwidth = elem_ty.getIntOrFloatBitWidth();
  return getNativeVregOrVmaskTypeImpl(elem_ty, elem_bitwidth, target_shape);
}
|
||||
|
||||
// Returns empty vector on null attribute
|
||||
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
|
||||
if (const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr)) {
|
||||
@ -610,7 +626,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
|
||||
|
||||
FAILUREOR_ASSIGN_OR_RETURN(
|
||||
const VectorType out_vreg_ty,
|
||||
getNativeVregType(out_ty.getElementType(), ctx.target_shape));
|
||||
getNativeVregOrVmaskType(out_ty.getElementType(), layout_out.bitwidth(),
|
||||
ctx.target_shape));
|
||||
|
||||
NamedAttrList attributes(op.getAttrDictionary());
|
||||
attributes.erase("in_layout");
|
||||
|
@ -179,9 +179,8 @@ class VectorLayoutInferer {
|
||||
TPU_CHECK_OP(static_cast<bool>(in_ty) == static_cast<bool>(out_ty),
|
||||
"Input and output are not both vectors?");
|
||||
if (in_ty) {
|
||||
TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1 &&
|
||||
out_ty.getElementTypeBitWidth() == 32,
|
||||
"Only 1 bit -> 32 bit extensison supported");
|
||||
TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1,
|
||||
"Only extending i1 is supported");
|
||||
}
|
||||
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
|
||||
return failure();
|
||||
@ -193,11 +192,7 @@ class VectorLayoutInferer {
|
||||
auto rhs_ty = dyn_cast<VectorType>(any_op.getOperand(1).getType());
|
||||
TPU_CHECK_OP(static_cast<bool>(lhs_ty) == static_cast<bool>(rhs_ty),
|
||||
"Only one side of cmp is a vector?");
|
||||
if (lhs_ty) {
|
||||
TPU_CHECK_OP(lhs_ty.getElementTypeBitWidth() == kNativeBitwidth &&
|
||||
rhs_ty.getElementTypeBitWidth() == kNativeBitwidth,
|
||||
"Only 32-bit cmp supported");
|
||||
}
|
||||
// TODO(tlongeri): Check that TPU generation supports comparison.
|
||||
if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
|
||||
return failure();
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user