llvm-project/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
Chris Lattner 41d4aa7de6 [SymbolRefAttr] Revise SymbolRefAttr to hold a StringAttr.
SymbolRefAttr is fundamentally a base string plus a sequence
of nested references.  Instead of storing the string data as
a copied StringRef, store it as an already-uniqued StringAttr.

This makes a lot of things simpler and more efficient because:
1) references to the symbol are already stored as StringAttr's:
   there is no need to copy the string data into MLIRContext
   multiple times.
2) This allows pointer comparisons instead of string
   comparisons (or redundant uniquing) within SymbolTable.cpp.
3) This allows SymbolTable to hold a DenseMap instead of a
   StringMap (which again copies the string data and slows
   lookup).

This is a moderately invasive patch, so I kept a lot of
compatibility APIs around.  It would be nice to explore changing
getName() to return a StringAttr for example (right now you have
to use getNameAttr()), and eliminate things like the StringRef
version of getSymbol.

Differential Revision: https://reviews.llvm.org/D108899
2021-08-29 21:54:47 -07:00

125 lines
4.3 KiB
C++

//===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
namespace mlir {
/// Rewriting that replaces SourceOp with a CallOp to `f32Func` or `f64Func`
/// depending on the element type that Op operates upon. The function
/// declaration is added in case it was not added before.
///
/// If the input values are of f16 type, the value is first casted to f32, the
/// function called and then the result casted back.
///
/// The pattern fails to match (rather than crashing) when:
///   * the op has no operands,
///   * no function name is registered for the (casted) element type,
///   * the callee symbol already exists but is not an `llvm.func`, or
///   * the callee has to be declared but `op` is not nested in an `llvm.func`
///     that could serve as the insertion anchor.
///
/// Example with NVVM:
///   %exp_f32 = std.exp %arg_f32 : f32
///
/// will be transformed into
///   llvm.call @__nv_expf(%arg_f32) : (f32) -> f32
template <typename SourceOp>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
public:
  /// `f32Func` / `f64Func` are the names of the functions to call for f32 and
  /// f64 operands respectively. The names are copied into the pattern.
  explicit OpToFuncCallLowering(LLVMTypeConverter &lowering_, StringRef f32Func,
                                StringRef f64Func)
      : ConvertOpToLLVMPattern<SourceOp>(lowering_), f32Func(f32Func),
        f64Func(f64Func) {}

  /// Replaces `op` with an `llvm.call` to the function selected by the
  /// operand type, extending f16 operands to f32 before the call and
  /// truncating the call result back to the original operand type afterwards.
  LogicalResult
  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    using LLVM::LLVMFuncOp;

    static_assert(
        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
        "expected single result op");
    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
                                  SourceOp>::value,
                  "expected op with same operand and result types");

    // The result type is derived from the first operand below, so there is
    // nothing sensible to do for an op without operands.
    if (operands.empty())
      return failure();

    SmallVector<Value, 1> castedOperands;
    for (Value operand : operands)
      castedOperands.push_back(maybeCast(operand, rewriter));

    Type resultType = castedOperands.front().getType();
    Type funcType = getFunctionType(resultType, castedOperands);
    StringRef funcName = getFunctionName(
        funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
    if (funcName.empty())
      return failure();

    // A null result means the callee could neither be found as an `llvm.func`
    // nor be declared; bail out instead of building a call to nothing.
    LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
    if (!funcOp)
      return failure();

    auto callOp = rewriter.create<LLVM::CallOp>(
        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
        castedOperands);

    // No cast was inserted on the way in, so none is needed on the way out.
    if (resultType == operands.front().getType()) {
      rewriter.replaceOp(op, {callOp.getResult(0)});
      return success();
    }

    // The operands were extended; truncate the result back to their type.
    Value truncated = rewriter.create<LLVM::FPTruncOp>(
        op->getLoc(), operands.front().getType(), callOp.getResult(0));
    rewriter.replaceOp(op, {truncated});
    return success();
  }

private:
  /// Returns `operand` unchanged unless it is of f16 type, in which case an
  /// `LLVM::FPExtOp` to f32 is created and its result returned.
  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
    Type type = operand.getType();
    if (!type.isa<Float16Type>())
      return operand;

    return rewriter.create<LLVM::FPExtOp>(
        operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
  }

  /// Builds the LLVM function type `(types of operands...) -> resultType`.
  Type getFunctionType(Type resultType, ArrayRef<Value> operands) const {
    SmallVector<Type, 1> operandTypes;
    for (Value operand : operands) {
      operandTypes.push_back(operand.getType());
    }
    return LLVM::LLVMFunctionType::get(resultType, operandTypes);
  }

  /// Returns the function name registered for `type` (f32 or f64), or an empty
  /// string for any other type.
  StringRef getFunctionName(Type type) const {
    if (type.isa<Float32Type>())
      return f32Func;
    if (type.isa<Float64Type>())
      return f64Func;
    return "";
  }

  /// Returns the `llvm.func` named `funcName` visible from `op`, declaring it
  /// right before the `llvm.func` enclosing `op` if no such symbol exists yet.
  ///
  /// Returns a null op when the symbol exists but is not an `llvm.func`, or
  /// when a declaration is needed but `op` has no enclosing `llvm.func`.
  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
                                     Operation *op) const {
    using LLVM::LLVMFuncOp;

    auto funcAttr = StringAttr::get(op->getContext(), funcName);
    Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
    // `dyn_cast` instead of `cast`: a symbol of a different op kind yields a
    // null result for the caller to handle rather than a failed assertion.
    if (funcOp)
      return dyn_cast<LLVMFuncOp>(funcOp);

    // The enclosing function is the insertion anchor for the declaration;
    // constructing an OpBuilder from a null operation is not valid.
    auto parentFunc = op->getParentOfType<LLVMFuncOp>();
    if (!parentFunc)
      return LLVMFuncOp();

    mlir::OpBuilder b(parentFunc);
    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
  }

  /// Name of the function called for f32 (and extended f16) operands.
  const std::string f32Func;
  /// Name of the function called for f64 operands.
  const std::string f64Func;
};
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_