[Mosaic] Relax return type checks for vector.contract

PiperOrigin-RevId: 553898552
This commit is contained in:
Adam Paszke 2023-08-04 13:33:38 -07:00 committed by jax authors
parent 6e873ab816
commit 16c33df3cf
2 changed files with 4 additions and 5 deletions

View File

@ -715,8 +715,8 @@ class VectorLayoutInferer {
};
auto res_ty = dyn_cast<VectorType>(op.getType());
TPU_CHECK_OP(res_ty, "only vector results supported");
TPU_CHECK_OP(res_ty.getElementType().isF32(),
"only fp32 matmul results supported");
TPU_CHECK_OP(res_ty.getElementTypeBitWidth() == kNativeBitwidth,
"only 32-bit matmul results supported");
std::array<Layout, 3> in_layout;
CHECK_EQ(op->getNumOperands(), 3);
for (int i = 0; i < 3; ++i) {

View File

@ -2317,9 +2317,8 @@ def _vector_contract_rule(ctx: RewriteContext, op: vector.ContractionOp, # pyli
lhs_type = ir.VectorType(op.lhs.type)
rhs_type = ir.VectorType(op.rhs.type)
acc_type = ir.VectorType(op.acc.type)
f32 = ir.F32Type.get()
if acc_type.element_type != f32:
raise NotImplementedError("non-fp32 matmuls")
if type_bitwidth(acc_type.element_type) != 32:
raise NotImplementedError("non-32-bit matmul result")
if lhs_type.shape[0] % layout_lhs.tiling[0] != 0:
raise NotImplementedError("layout matmul lhs")
if rhs_type.shape == [128, 128]: