llvm-project/mlir/lib/Interfaces/FunctionInterfaces.cpp
Martin Erhart 34a35a8b24 [mlir] Move FunctionInterfaces to Interfaces directory and inherit from CallableOpInterface
Functions are always callable operations and thus every operation
implementing the `FunctionOpInterface` also implements the
`CallableOpInterface`. The only exception was the FuncOp in the toy
example. To make implementation of the `FunctionOpInterface` easier,
this commit lets `FunctionOpInterface` inherit from
`CallableOpInterface` and merges some of their methods. More precisely,
the `CallableOpInterface` has methods to get the argument and result
attributes and a method to get the result types of the callable region.
These methods are always implemented the same way as their analogues in
`FunctionOpInterface` and thus this commit moves all the argument and
result attribute handling methods to the callable interface as well as
the methods to get the argument and result types. The
`FuntionOpInterface` then does not have to declare them as well, but
just inherits them from the `CallableOpInterface`.
Adding the inheritance relation also required to move the
`FunctionOpInterface` from the IR directory to the Interfaces directory
since IR should not depend on Interfaces.

Reviewed By: jpienaar, springerm

Differential Revision: https://reviews.llvm.org/D157988
2023-08-31 11:28:23 +00:00

362 lines
14 KiB
C++

//===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/FunctionInterfaces.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/FunctionInterfaces.cpp.inc"
//===----------------------------------------------------------------------===//
// Function Arguments and Results.
//===----------------------------------------------------------------------===//
static bool isEmptyAttrDict(Attribute attr) {
return llvm::cast<DictionaryAttr>(attr).empty();
}
DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op,
unsigned index) {
ArrayAttr attrs = op.getArgAttrsAttr();
DictionaryAttr argAttrs =
attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
return argAttrs;
}
DictionaryAttr
function_interface_impl::getResultAttrDict(FunctionOpInterface op,
unsigned index) {
ArrayAttr attrs = op.getResAttrsAttr();
DictionaryAttr resAttrs =
attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr();
return resAttrs;
}
ArrayRef<NamedAttribute>
function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) {
auto argDict = getArgAttrDict(op, index);
return argDict ? argDict.getValue() : std::nullopt;
}
ArrayRef<NamedAttribute>
function_interface_impl::getResultAttrs(FunctionOpInterface op,
unsigned index) {
auto resultDict = getResultAttrDict(op, index);
return resultDict ? resultDict.getValue() : std::nullopt;
}
/// Get either the argument or result attributes array.
template <bool isArg>
static ArrayAttr getArgResAttrs(FunctionOpInterface op) {
if constexpr (isArg)
return op.getArgAttrsAttr();
else
return op.getResAttrsAttr();
}
/// Set either the argument or result attributes array.
template <bool isArg>
static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) {
if constexpr (isArg)
op.setArgAttrsAttr(attrs);
else
op.setResAttrsAttr(attrs);
}
/// Erase either the argument or result attributes array.
template <bool isArg>
static void removeArgResAttrs(FunctionOpInterface op) {
if constexpr (isArg)
op.removeArgAttrsAttr();
else
op.removeResAttrsAttr();
}
/// Set all of the argument or result attribute dictionaries for a function.
template <bool isArg>
static void setAllArgResAttrDicts(FunctionOpInterface op,
ArrayRef<Attribute> attrs) {
if (llvm::all_of(attrs, isEmptyAttrDict))
removeArgResAttrs<isArg>(op);
else
setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs));
}
void function_interface_impl::setAllArgAttrDicts(
FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
}
void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op,
ArrayRef<Attribute> attrs) {
auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
return !attr ? DictionaryAttr::get(op->getContext()) : attr;
});
setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(wrappedAttrs));
}
void function_interface_impl::setAllResultAttrDicts(
FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) {
setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size()));
}
void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op,
ArrayRef<Attribute> attrs) {
auto wrappedAttrs = llvm::map_range(attrs, [op](Attribute attr) -> Attribute {
return !attr ? DictionaryAttr::get(op->getContext()) : attr;
});
setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(wrappedAttrs));
}
/// Update the given index into an argument or result attribute dictionary.
template <bool isArg>
static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices,
unsigned index, DictionaryAttr attrs) {
ArrayAttr allAttrs = getArgResAttrs<isArg>(op);
if (!allAttrs) {
if (attrs.empty())
return;
// If this attribute is not empty, we need to create a new attribute array.
SmallVector<Attribute, 8> newAttrs(numTotalIndices,
DictionaryAttr::get(op->getContext()));
newAttrs[index] = attrs;
setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
return;
}
// Check to see if the attribute is different from what we already have.
if (allAttrs[index] == attrs)
return;
// If it is, check to see if the attribute array would now contain only empty
// dictionaries.
ArrayRef<Attribute> rawAttrArray = allAttrs.getValue();
if (attrs.empty() &&
llvm::all_of(rawAttrArray.take_front(index), isEmptyAttrDict) &&
llvm::all_of(rawAttrArray.drop_front(index + 1), isEmptyAttrDict))
return removeArgResAttrs<isArg>(op);
// Otherwise, create a new attribute array with the updated dictionary.
SmallVector<Attribute, 8> newAttrs(rawAttrArray.begin(), rawAttrArray.end());
newAttrs[index] = attrs;
setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs));
}
void function_interface_impl::setArgAttrs(FunctionOpInterface op,
unsigned index,
ArrayRef<NamedAttribute> attributes) {
assert(index < op.getNumArguments() && "invalid argument number");
return setArgResAttrDict</*isArg=*/true>(
op, op.getNumArguments(), index,
DictionaryAttr::get(op->getContext(), attributes));
}
void function_interface_impl::setArgAttrs(FunctionOpInterface op,
unsigned index,
DictionaryAttr attributes) {
return setArgResAttrDict</*isArg=*/true>(
op, op.getNumArguments(), index,
attributes ? attributes : DictionaryAttr::get(op->getContext()));
}
void function_interface_impl::setResultAttrs(
FunctionOpInterface op, unsigned index,
ArrayRef<NamedAttribute> attributes) {
assert(index < op.getNumResults() && "invalid result number");
return setArgResAttrDict</*isArg=*/false>(
op, op.getNumResults(), index,
DictionaryAttr::get(op->getContext(), attributes));
}
void function_interface_impl::setResultAttrs(FunctionOpInterface op,
unsigned index,
DictionaryAttr attributes) {
assert(index < op.getNumResults() && "invalid result number");
return setArgResAttrDict</*isArg=*/false>(
op, op.getNumResults(), index,
attributes ? attributes : DictionaryAttr::get(op->getContext()));
}
void function_interface_impl::insertFunctionArguments(
FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes,
ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs,
unsigned originalNumArgs, Type newType) {
assert(argIndices.size() == argTypes.size());
assert(argIndices.size() == argAttrs.size() || argAttrs.empty());
assert(argIndices.size() == argLocs.size());
if (argIndices.empty())
return;
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
// - Block arguments of entry block.
Block &entry = op->getRegion(0).front();
// Update the argument attributes of the function.
ArrayAttr oldArgAttrs = op.getArgAttrsAttr();
if (oldArgAttrs || !argAttrs.empty()) {
SmallVector<DictionaryAttr, 4> newArgAttrs;
newArgAttrs.reserve(originalNumArgs + argIndices.size());
unsigned oldIdx = 0;
auto migrate = [&](unsigned untilIdx) {
if (!oldArgAttrs) {
newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx);
} else {
auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>();
newArgAttrs.append(oldArgAttrRange.begin() + oldIdx,
oldArgAttrRange.begin() + untilIdx);
}
oldIdx = untilIdx;
};
for (unsigned i = 0, e = argIndices.size(); i < e; ++i) {
migrate(argIndices[i]);
newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]);
}
migrate(originalNumArgs);
setAllArgAttrDicts(op, newArgAttrs);
}
// Update the function type and any entry block arguments.
op.setFunctionTypeAttr(TypeAttr::get(newType));
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
entry.insertArgument(argIndices[i] + i, argTypes[i], argLocs[i]);
}
void function_interface_impl::insertFunctionResults(
FunctionOpInterface op, ArrayRef<unsigned> resultIndices,
TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs,
unsigned originalNumResults, Type newType) {
assert(resultIndices.size() == resultTypes.size());
assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty());
if (resultIndices.empty())
return;
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
// Update the result attributes of the function.
ArrayAttr oldResultAttrs = op.getResAttrsAttr();
if (oldResultAttrs || !resultAttrs.empty()) {
SmallVector<DictionaryAttr, 4> newResultAttrs;
newResultAttrs.reserve(originalNumResults + resultIndices.size());
unsigned oldIdx = 0;
auto migrate = [&](unsigned untilIdx) {
if (!oldResultAttrs) {
newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx);
} else {
auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>();
newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx,
oldResultAttrsRange.begin() + untilIdx);
}
oldIdx = untilIdx;
};
for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) {
migrate(resultIndices[i]);
newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{}
: resultAttrs[i]);
}
migrate(originalNumResults);
setAllResultAttrDicts(op, newResultAttrs);
}
// Update the function type.
op.setFunctionTypeAttr(TypeAttr::get(newType));
}
void function_interface_impl::eraseFunctionArguments(
FunctionOpInterface op, const BitVector &argIndices, Type newType) {
// There are 3 things that need to be updated:
// - Function type.
// - Arg attrs.
// - Block arguments of entry block.
Block &entry = op->getRegion(0).front();
// Update the argument attributes of the function.
if (ArrayAttr argAttrs = op.getArgAttrsAttr()) {
SmallVector<DictionaryAttr, 4> newArgAttrs;
newArgAttrs.reserve(argAttrs.size());
for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
if (!argIndices[i])
newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i]));
setAllArgAttrDicts(op, newArgAttrs);
}
// Update the function type and any entry block arguments.
op.setFunctionTypeAttr(TypeAttr::get(newType));
entry.eraseArguments(argIndices);
}
void function_interface_impl::eraseFunctionResults(
FunctionOpInterface op, const BitVector &resultIndices, Type newType) {
// There are 2 things that need to be updated:
// - Function type.
// - Result attrs.
// Update the result attributes of the function.
if (ArrayAttr resAttrs = op.getResAttrsAttr()) {
SmallVector<DictionaryAttr, 4> newResultAttrs;
newResultAttrs.reserve(resAttrs.size());
for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
if (!resultIndices[i])
newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i]));
setAllResultAttrDicts(op, newResultAttrs);
}
// Update the function type.
op.setFunctionTypeAttr(TypeAttr::get(newType));
}
//===----------------------------------------------------------------------===//
// Function type signature.
//===----------------------------------------------------------------------===//
void function_interface_impl::setFunctionType(FunctionOpInterface op,
Type newType) {
unsigned oldNumArgs = op.getNumArguments();
unsigned oldNumResults = op.getNumResults();
op.setFunctionTypeAttr(TypeAttr::get(newType));
unsigned newNumArgs = op.getNumArguments();
unsigned newNumResults = op.getNumResults();
// Functor used to update the argument and result attributes of the function.
auto emptyDict = DictionaryAttr::get(op.getContext());
auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) {
constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>;
if (oldCount == newCount)
return;
// The new type has no arguments/results, just drop the attribute.
if (newCount == 0)
return removeArgResAttrs<isArgVal>(op);
ArrayAttr attrs = getArgResAttrs<isArgVal>(op);
if (!attrs)
return;
// The new type has less arguments/results, take the first N attributes.
if (newCount < oldCount)
return setAllArgResAttrDicts<isArgVal>(
op, attrs.getValue().take_front(newCount));
// Otherwise, the new type has more arguments/results. Initialize the new
// arguments/results with empty dictionary attributes.
SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end());
newAttrs.resize(newCount, emptyDict);
setAllArgResAttrDicts<isArgVal>(op, newAttrs);
};
// Update the argument and result attributes.
updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs);
updateAttrFn(std::false_type{}, oldNumResults, newNumResults);
}