[mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC (#113799)

This commit introduces a new MathToEmitC conversion pass that lowers
selected math operations from the Math dialect to the emitc.call_opaque
operation in the EmitC dialect.

**Supported Math Operations:**
The following operations are converted:

- math.floor -> emitc.call_opaque<"floor">
- math.round -> emitc.call_opaque<"round">
- math.exp -> emitc.call_opaque<"exp">
- math.cos -> emitc.call_opaque<"cos">
- math.sin -> emitc.call_opaque<"sin">
- math.acos -> emitc.call_opaque<"acos">
- math.asin -> emitc.call_opaque<"asin">
- math.atan2 -> emitc.call_opaque<"atan2">
- math.ceil -> emitc.call_opaque<"ceil">
- math.absf -> emitc.call_opaque<"fabs">
- math.powf -> emitc.call_opaque<"pow">

**Target Language Standards:**
The pass supports targeting different language standards:

- C99: Generates calls with suffixes (e.g., floorf, fabsf) for
single-precision floats.
- CPP11: Prepends std:: to functions (e.g., std::floor, std::fabs).

**Design Decisions:**
The pass uses emitc.call_opaque instead of emitc.call to better emulate
C-style function overloading.
emitc.call_opaque does not require a unique type signature, making it
more suitable for operations like <math.h> functions that may be
overloaded for different types.
This design choice ensures compatibility with C/C++ conventions.
This commit is contained in:
Tomer Solomon 2025-01-20 10:26:41 +02:00 committed by GitHub
parent 7a77f14c0a
commit c0055ec434
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 384 additions and 0 deletions

View File

@ -0,0 +1,25 @@
//===- MathToEmitC.h - Math to EmitC Patterns -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
#include "mlir/Dialect/EmitC/IR/EmitC.h"
namespace mlir {
class RewritePatternSet;
namespace emitc {
/// Enum to specify the language target for EmitC code generation.
enum class LanguageTarget { c99, cpp11 };
} // namespace emitc
void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns,
emitc::LanguageTarget languageTarget);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H

View File

@ -0,0 +1,21 @@
//===- MathToEmitCPass.h - Math to EmitC Pass -------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
#include <memory>
namespace mlir {
class Pass;
#define GEN_PASS_DECL_CONVERTMATHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H

View File

@ -43,6 +43,7 @@
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"

View File

@ -790,6 +790,28 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
// MathToEmitC
//===----------------------------------------------------------------------===//
def ConvertMathToEmitC : Pass<"convert-math-to-emitc"> {
let summary = "Convert some Math operations to EmitC call_opaque operations";
let description = [{
This pass converts supported Math ops to `call_opaque` ops targeting libc/libm
functions. Unlike convert-math-to-funcs pass, converting to `call_opaque` ops
allows to overload the same function with different argument types.
}];
let dependentDialects = ["emitc::EmitCDialect"];
let options = [
Option<"languageTarget", "language-target", "::mlir::emitc::LanguageTarget",
/*default=*/"::mlir::emitc::LanguageTarget::c99", "Select the language standard target for callees (c99 or cpp11).",
[{::llvm::cl::values(
clEnumValN(::mlir::emitc::LanguageTarget::c99, "c99", "c99"),
clEnumValN(::mlir::emitc::LanguageTarget::cpp11, "cpp11", "cpp11")
)}]>
];
}
//===----------------------------------------------------------------------===//
// MathToFuncs
//===----------------------------------------------------------------------===//

View File

@ -33,6 +33,7 @@ add_subdirectory(IndexToLLVM)
add_subdirectory(IndexToSPIRV)
add_subdirectory(LinalgToStandard)
add_subdirectory(LLVMCommon)
add_subdirectory(MathToEmitC)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)

View File

@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRMathToEmitC
MathToEmitC.cpp
MathToEmitCPass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC
DEPENDS
MLIRConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIREmitCDialect
MLIRMathDialect
MLIRPass
MLIRTransformUtils
)

View File

@ -0,0 +1,85 @@
//===- MathToEmitC.cpp - Math to EmitC Patterns -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
namespace {
template <typename OpType>
class LowerToEmitCCallOpaque : public OpRewritePattern<OpType> {
std::string calleeStr;
emitc::LanguageTarget languageTarget;
public:
LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr,
emitc::LanguageTarget languageTarget)
: OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)),
languageTarget(languageTarget) {}
LogicalResult matchAndRewrite(OpType op,
PatternRewriter &rewriter) const override;
};
template <typename OpType>
LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
OpType op, PatternRewriter &rewriter) const {
if (!llvm::all_of(op->getOperandTypes(),
llvm::IsaPred<Float32Type, Float64Type>) ||
!llvm::all_of(op->getResultTypes(),
llvm::IsaPred<Float32Type, Float64Type>))
return rewriter.notifyMatchFailure(
op.getLoc(),
"expected all operands and results to be of type f32 or f64");
std::string modifiedCalleeStr = calleeStr;
if (languageTarget == emitc::LanguageTarget::cpp11) {
modifiedCalleeStr = "std::" + calleeStr;
} else if (languageTarget == emitc::LanguageTarget::c99) {
auto operandType = op->getOperandTypes()[0];
if (operandType.isF32())
modifiedCalleeStr = calleeStr + "f";
}
rewriter.replaceOpWithNewOp<emitc::CallOpaqueOp>(
op, op.getType(), modifiedCalleeStr, op->getOperands());
return success();
}
} // namespace
// Populates patterns to replace `math` operations with `emitc.call_opaque`,
// using function names consistent with those in <math.h>.
void mlir::populateConvertMathToEmitCPatterns(
RewritePatternSet &patterns, emitc::LanguageTarget languageTarget) {
auto *context = patterns.getContext();
patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::RoundOp>>(context, "round",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs",
languageTarget);
patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow",
languageTarget);
}

View File

@ -0,0 +1,53 @@
//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the Math dialect to the EmitC dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
// Replaces Math operations with `emitc.call_opaque` operations.
struct ConvertMathToEmitC
: public impl::ConvertMathToEmitCBase<ConvertMathToEmitC> {
using ConvertMathToEmitCBase::ConvertMathToEmitCBase;
public:
void runOnOperation() final;
};
} // namespace
void ConvertMathToEmitC::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalOp<emitc::CallOpaqueOp>();
target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundOp, math::CosOp,
math::SinOp, math::Atan2Op, math::CeilOp, math::AcosOp,
math::AsinOp, math::AbsFOp, math::PowFOp>();
RewritePatternSet patterns(&getContext());
populateConvertMathToEmitCPatterns(patterns, languageTarget);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}

View File

@ -0,0 +1,23 @@
// RUN: mlir-opt -split-input-file -convert-math-to-emitc -verify-diagnostics %s
func.func @unsupported_tensor_type(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
%0 = math.absf %arg0 : tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
func.func @unsupported_f16_type(%arg0 : f16) -> f16 {
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
%0 = math.absf %arg0 : f16
return %0 : f16
}
// -----
func.func @unsupported_f128_type(%arg0 : f128) -> f128 {
// expected-error @+1 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
%0 = math.absf %arg0 : f128
return %0 : f128
}

View File

@ -0,0 +1,112 @@
// RUN: mlir-opt -convert-math-to-emitc=language-target=c99 %s | FileCheck %s --check-prefix=c99
// RUN: mlir-opt -convert-math-to-emitc=language-target=cpp11 %s | FileCheck %s --check-prefix=cpp11
func.func @absf(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "fabsf"
// c99-NEXT: emitc.call_opaque "fabs"
// cpp11: emitc.call_opaque "std::fabs"
// cpp11-NEXT: emitc.call_opaque "std::fabs"
%0 = math.absf %arg0 : f32
%1 = math.absf %arg1 : f64
return
}
func.func @floor(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "floorf"
// c99-NEXT: emitc.call_opaque "floor"
// cpp11: emitc.call_opaque "std::floor"
// cpp11-NEXT: emitc.call_opaque "std::floor"
%0 = math.floor %arg0 : f32
%1 = math.floor %arg1 : f64
return
}
func.func @sin(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "sinf"
// c99-NEXT: emitc.call_opaque "sin"
// cpp11: emitc.call_opaque "std::sin"
// cpp11-NEXT: emitc.call_opaque "std::sin"
%0 = math.sin %arg0 : f32
%1 = math.sin %arg1 : f64
return
}
func.func @cos(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "cosf"
// c99-NEXT: emitc.call_opaque "cos"
// cpp11: emitc.call_opaque "std::cos"
// cpp11-NEXT: emitc.call_opaque "std::cos"
%0 = math.cos %arg0 : f32
%1 = math.cos %arg1 : f64
return
}
func.func @asin(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "asinf"
// c99-NEXT: emitc.call_opaque "asin"
// cpp11: emitc.call_opaque "std::asin"
// cpp11-NEXT: emitc.call_opaque "std::asin"
%0 = math.asin %arg0 : f32
%1 = math.asin %arg1 : f64
return
}
func.func @acos(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "acosf"
// c99-NEXT: emitc.call_opaque "acos"
// cpp11: emitc.call_opaque "std::acos"
// cpp11-NEXT: emitc.call_opaque "std::acos"
%0 = math.acos %arg0 : f32
%1 = math.acos %arg1 : f64
return
}
func.func @atan2(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) {
// c99: emitc.call_opaque "atan2f"
// c99-NEXT: emitc.call_opaque "atan2"
// cpp11: emitc.call_opaque "std::atan2"
// cpp11-NEXT: emitc.call_opaque "std::atan2"
%0 = math.atan2 %arg0, %arg1 : f32
%1 = math.atan2 %arg2, %arg3 : f64
return
}
func.func @ceil(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "ceilf"
// c99-NEXT: emitc.call_opaque "ceil"
// cpp11: emitc.call_opaque "std::ceil"
// cpp11-NEXT: emitc.call_opaque "std::ceil"
%0 = math.ceil %arg0 : f32
%1 = math.ceil %arg1 : f64
return
}
func.func @exp(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "expf"
// c99-NEXT: emitc.call_opaque "exp"
// cpp11: emitc.call_opaque "std::exp"
// cpp11-NEXT: emitc.call_opaque "std::exp"
%0 = math.exp %arg0 : f32
%1 = math.exp %arg1 : f64
return
}
func.func @powf(%arg0: f32, %arg1: f32, %arg2: f64, %arg3: f64) {
// c99: emitc.call_opaque "powf"
// c99-NEXT: emitc.call_opaque "pow"
// cpp11: emitc.call_opaque "std::pow"
// cpp11-NEXT: emitc.call_opaque "std::pow"
%0 = math.powf %arg0, %arg1 : f32
%1 = math.powf %arg2, %arg3 : f64
return
}
func.func @round(%arg0: f32, %arg1: f64) {
// c99: emitc.call_opaque "roundf"
// c99-NEXT: emitc.call_opaque "round"
// cpp11: emitc.call_opaque "std::round"
// cpp11-NEXT: emitc.call_opaque "std::round"
%0 = math.round %arg0 : f32
%1 = math.round %arg1 : f64
return
}

View File

@ -4313,6 +4313,7 @@ cc_library(
":IndexToLLVM",
":IndexToSPIRV",
":LinalgToStandard",
":MathToEmitC",
":MathToFuncs",
":MathToLLVM",
":MathToLibm",
@ -8901,6 +8902,27 @@ cc_library(
],
)
cc_library(
name = "MathToEmitC",
srcs = glob([
"lib/Conversion/MathToEmitC/*.cpp",
]),
hdrs = glob([
"include/mlir/Conversion/MathToEmitC/*.h",
]),
includes = [
"include",
"lib/Conversion/MathToEmitC",
],
deps = [
":ConversionPassIncGen",
":EmitCDialect",
":MathDialect",
":Pass",
":TransformUtils",
],
)
cc_library(
name = "MathToFuncs",
srcs = glob(["lib/Conversion/MathToFuncs/*.cpp"]),