[mlir][emitc] Support convert arith.extf and arith.truncf to emitc (#121184)

This commit is contained in:
Jianjian Guan 2025-01-16 14:57:43 +08:00 committed by GitHub
parent 0195ec452e
commit f9a8006247
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 106 additions and 1 deletions

View File

@ -733,6 +733,43 @@ public:
}
};
// Floating-point to floating-point conversions.
template <typename CastOp>
class FpCastOpConversion : public OpConversionPattern<CastOp> {
public:
FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
: OpConversionPattern<CastOp>(typeConverter, context) {}
LogicalResult
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Vectors in particular are not supported.
Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedFloatType(operandType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");
if (auto roundingModeOp =
dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
// Only supporting default rounding mode as of now.
if (roundingModeOp.getRoundingModeAttr())
return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
}
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
if (!emitc::isSupportedFloatType(dstType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");
Value fpCastOperand = adaptor.getIn();
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
@ -778,7 +815,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ItoFCastOpConversion<arith::SIToFPOp>,
ItoFCastOpConversion<arith::UIToFPOp>,
FtoICastOpConversion<arith::FPToSIOp>,
FtoICastOpConversion<arith::FPToUIOp>
FtoICastOpConversion<arith::FPToUIOp>,
FpCastOpConversion<arith::ExtFOp>,
FpCastOpConversion<arith::TruncFOp>
>(typeConverter, ctx);
// clang-format on
}

View File

@ -149,3 +149,43 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec
%divui = arith.remui %arg0, %arg1 : vector<5xi32>
return %divui: vector<5xi32>
}
// -----
func.func @arith_truncf(%arg0: f64) -> f32 {
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
%truncd = arith.truncf %arg0 to_nearest_away : f64 to f32
return %truncd : f32
}
// -----
func.func @arith_extf_f128(%arg0: f32) -> f128 {
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
%extd = arith.extf %arg0 : f32 to f128
return %extd : f128
}
// -----
func.func @arith_truncf_f128(%arg0: f128) -> f32 {
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
%truncd = arith.truncf %arg0 : f128 to f32
return %truncd : f32
}
// -----
func.func @arith_extf_vector(%arg0: vector<4xf32>) -> vector<4xf64> {
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
%extd = arith.extf %arg0 : vector<4xf32> to vector<4xf64>
return %extd : vector<4xf64>
}
// -----
func.func @arith_truncf_vector(%arg0: vector<4xf64>) -> vector<4xf32> {
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
%truncd = arith.truncf %arg0 : vector<4xf64> to vector<4xf32>
return %truncd : vector<4xf32>
}

View File

@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 {
return %div : i32
}
// -----
func.func @arith_extf(%arg0: f16) -> f64 {
// CHECK-LABEL: arith_extf
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f16)
// CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32
%extd0 = arith.extf %arg0 : f16 to f32
// CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64
%extd1 = arith.extf %extd0 : f32 to f64
return %extd1 : f64
}
// -----
func.func @arith_truncf(%arg0: f64) -> f16 {
// CHECK-LABEL: arith_truncf
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
// CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
%truncd0 = arith.truncf %arg0 : f64 to f32
// CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
%truncd1 = arith.truncf %truncd0 : f32 to f16
return %truncd1 : f16
}