mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-27 18:36:05 +00:00
[mlir][emitc] Support convert arith.extf and arith.truncf to emitc (#121184)
This commit is contained in:
parent
0195ec452e
commit
f9a8006247
@ -733,6 +733,43 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// Floating-point to floating-point conversions.
|
||||
template <typename CastOp>
|
||||
class FpCastOpConversion : public OpConversionPattern<CastOp> {
|
||||
public:
|
||||
FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
|
||||
: OpConversionPattern<CastOp>(typeConverter, context) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Vectors in particular are not supported.
|
||||
Type operandType = adaptor.getIn().getType();
|
||||
if (!emitc::isSupportedFloatType(operandType))
|
||||
return rewriter.notifyMatchFailure(castOp,
|
||||
"unsupported cast source type");
|
||||
if (auto roundingModeOp =
|
||||
dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
|
||||
// Only supporting default rounding mode as of now.
|
||||
if (roundingModeOp.getRoundingModeAttr())
|
||||
return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
|
||||
}
|
||||
|
||||
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
|
||||
if (!dstType)
|
||||
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
|
||||
|
||||
if (!emitc::isSupportedFloatType(dstType))
|
||||
return rewriter.notifyMatchFailure(castOp,
|
||||
"unsupported cast destination type");
|
||||
|
||||
Value fpCastOperand = adaptor.getIn();
|
||||
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -778,7 +815,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
|
||||
ItoFCastOpConversion<arith::SIToFPOp>,
|
||||
ItoFCastOpConversion<arith::UIToFPOp>,
|
||||
FtoICastOpConversion<arith::FPToSIOp>,
|
||||
FtoICastOpConversion<arith::FPToUIOp>
|
||||
FtoICastOpConversion<arith::FPToUIOp>,
|
||||
FpCastOpConversion<arith::ExtFOp>,
|
||||
FpCastOpConversion<arith::TruncFOp>
|
||||
>(typeConverter, ctx);
|
||||
// clang-format on
|
||||
}
|
||||
|
@ -149,3 +149,43 @@ func.func @arith_remui_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) -> vec
|
||||
%divui = arith.remui %arg0, %arg1 : vector<5xi32>
|
||||
return %divui: vector<5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @arith_truncf(%arg0: f64) -> f32 {
|
||||
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
|
||||
%truncd = arith.truncf %arg0 to_nearest_away : f64 to f32
|
||||
return %truncd : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @arith_extf_f128(%arg0: f32) -> f128 {
|
||||
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
|
||||
%extd = arith.extf %arg0 : f32 to f128
|
||||
return %extd : f128
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @arith_truncf_f128(%arg0: f128) -> f32 {
|
||||
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
|
||||
%truncd = arith.truncf %arg0 : f128 to f32
|
||||
return %truncd : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @arith_extf_vector(%arg0: vector<4xf32>) -> vector<4xf64> {
|
||||
// expected-error @+1 {{failed to legalize operation 'arith.extf'}}
|
||||
%extd = arith.extf %arg0 : vector<4xf32> to vector<4xf64>
|
||||
return %extd : vector<4xf64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @arith_truncf_vector(%arg0: vector<4xf64>) -> vector<4xf32> {
|
||||
// expected-error @+1 {{failed to legalize operation 'arith.truncf'}}
|
||||
%truncd = arith.truncf %arg0 : vector<4xf64> to vector<4xf32>
|
||||
return %truncd : vector<4xf32>
|
||||
}
|
||||
|
@ -739,3 +739,29 @@ func.func @arith_divui_remui(%arg0: i32, %arg1: i32) -> i32 {
|
||||
|
||||
return %div : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @arith_extf(%arg0: f16) -> f64 {
|
||||
// CHECK-LABEL: arith_extf
|
||||
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f16)
|
||||
// CHECK: %[[Extd0:.*]] = emitc.cast %[[Arg0]] : f16 to f32
|
||||
%extd0 = arith.extf %arg0 : f16 to f32
|
||||
// CHECK: %[[Extd1:.*]] = emitc.cast %[[Extd0]] : f32 to f64
|
||||
%extd1 = arith.extf %extd0 : f32 to f64
|
||||
|
||||
return %extd1 : f64
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @arith_truncf(%arg0: f64) -> f16 {
|
||||
// CHECK-LABEL: arith_truncf
|
||||
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
|
||||
// CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
|
||||
%truncd0 = arith.truncf %arg0 : f64 to f32
|
||||
// CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
|
||||
%truncd1 = arith.truncf %truncd0 : f32 to f16
|
||||
|
||||
return %truncd1 : f16
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user