mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Mosaic] Relax return type checks for vector.contract
PiperOrigin-RevId: 553898552
This commit is contained in:
parent
6e873ab816
commit
16c33df3cf
@ -715,8 +715,8 @@ class VectorLayoutInferer {
|
||||
};
|
||||
auto res_ty = dyn_cast<VectorType>(op.getType());
|
||||
TPU_CHECK_OP(res_ty, "only vector results supported");
|
||||
TPU_CHECK_OP(res_ty.getElementType().isF32(),
|
||||
"only fp32 matmul results supported");
|
||||
TPU_CHECK_OP(res_ty.getElementTypeBitWidth() == kNativeBitwidth,
|
||||
"only 32-bit matmul results supported");
|
||||
std::array<Layout, 3> in_layout;
|
||||
CHECK_EQ(op->getNumOperands(), 3);
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
|
@ -2317,9 +2317,8 @@ def _vector_contract_rule(ctx: RewriteContext, op: vector.ContractionOp, # pyli
|
||||
lhs_type = ir.VectorType(op.lhs.type)
|
||||
rhs_type = ir.VectorType(op.rhs.type)
|
||||
acc_type = ir.VectorType(op.acc.type)
|
||||
f32 = ir.F32Type.get()
|
||||
if acc_type.element_type != f32:
|
||||
raise NotImplementedError("non-fp32 matmuls")
|
||||
if type_bitwidth(acc_type.element_type) != 32:
|
||||
raise NotImplementedError("non-32-bit matmul result")
|
||||
if lhs_type.shape[0] % layout_lhs.tiling[0] != 0:
|
||||
raise NotImplementedError("layout matmul lhs")
|
||||
if rhs_type.shape == [128, 128]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user