Revert "[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#96207)"

This reverts commit f1e0657d144f5a3cfef4b625d0f875f4dacd21d1.

It breaks SCF conversion, see test case on the PR.
This commit is contained in:
Benjamin Kramer 2024-06-27 09:15:15 +02:00
parent 605098dcd4
commit 4d46b460f9
4 changed files with 159 additions and 122 deletions

View File

@ -246,13 +246,6 @@ depending on the situation.
- An argument materialization is used when converting the type of a block
argument during a [signature conversion](#region-signature-conversion).
The new block argument types are specified in a `SignatureConversion`
object. An original block argument can be converted into multiple
block arguments, which is not supported everywhere in the dialect
conversion. (E.g., adaptors support only a single replacement value for
each original value.) Therefore, an argument materialization is used to
convert potentially multiple new block arguments back into a single SSA
value.
* Source Materialization
@ -266,9 +259,6 @@ depending on the situation.
* When a block argument has been converted to a different type, but
the original argument still has users that will remain live after
the conversion process has finished.
* When a block argument has been dropped, but the argument still has
users that will remain live after the conversion process has
finished.
* When the result type of an operation has been converted to a
different type, but the original result still has users that will
remain live after the conversion process is finished.
@ -338,22 +328,19 @@ class TypeConverter {
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
/// All of the following materializations require function objects that are
/// convertible to the following form:
/// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
/// "casts" a range of values into a single value of the given type `T`. It
/// must return a Value of the converted type on success, an `std::nullopt` if
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. It will only be called for (sub)types of `T`.
/// Materialization functions must be provided when a type conversion may
/// persist after the conversion has finished.
/// Register a materialization function, which must be convertible to the
/// following form:
/// `Optional<Value> (OpBuilder &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`.
/// This function is responsible for creating an operation, using the
/// OpBuilder and Location provided, that "converts" a range of values into a
/// single value of the given type `T`. It must return a Value of the
/// converted type on success, an `std::nullopt` if it failed but other
/// materialization can be attempted, and `nullptr` on unrecoverable failure.
/// It will only be called for (sub)types of `T`.
///
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// of a legal type.
/// converting an illegal block argument type, to a legal type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@ -361,9 +348,8 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting a legal replacement value back to an illegal source type.
/// This is used when some uses of the original, illegal value must persist
/// beyond the main conversion.
/// converting a legal type to an illegal source type. This is used when
/// conversions to an illegal type must persist beyond the main conversion.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
@ -371,7 +357,7 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
/// converting type from an illegal, or source, type to a legal type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {

View File

@ -168,8 +168,8 @@ public:
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
/// All of the following materializations require function objects that are
/// convertible to the following form:
/// Register a materialization function, which must be convertible to the
/// following form:
/// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
@ -179,11 +179,9 @@ public:
/// unrecoverable failure. It will only be called for (sub)types of `T`.
/// Materialization functions must be provided when a type conversion may
/// persist after the conversion has finished.
///
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
/// a signature conversion of a single block argument, to a single SSA value
/// of a legal type.
/// converting an illegal block argument type, to a legal type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@ -191,9 +189,8 @@ public:
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting a legal replacement value back to an illegal source type.
/// This is used when some uses of the original, illegal value must persist
/// beyond the main conversion.
/// converting a legal type to an illegal source type. This is used when
/// conversions to an illegal type must persist beyond the main conversion.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
@ -201,7 +198,7 @@ public:
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
/// converting type from an illegal, or source, type to a legal type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {

View File

@ -432,14 +432,34 @@ private:
Block *insertBeforeBlock;
};
/// This structure contains the information pertaining to an argument that has
/// been converted.
struct ConvertedArgInfo {
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
Value castValue = nullptr)
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
/// The start index of in the new argument list that contains arguments that
/// replace the original.
unsigned newArgIdx;
/// The number of arguments that replaced the original argument.
unsigned newArgSize;
/// The cast value that was created to cast from the new arguments to the
/// old. This only used if 'newArgSize' > 1.
Value castValue;
};
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block, Block *origBlock,
const TypeConverter *converter)
BlockTypeConversionRewrite(
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
origBlock(origBlock), converter(converter) {}
origBlock(origBlock), argInfo(argInfo), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
@ -459,6 +479,10 @@ private:
/// The original block that was requested to have its signature converted.
Block *origBlock;
/// The conversion information for each of the arguments. The information is
/// std::nullopt if the argument was dropped during conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
/// The type converter used to convert the arguments.
const TypeConverter *converter;
};
@ -672,11 +696,7 @@ enum MaterializationKind {
/// This materialization materializes a conversion from an illegal type to a
/// legal one.
Target,
/// This materialization materializes a conversion from a legal type back to
/// an illegal one.
Source
Target
};
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@ -688,13 +708,9 @@ public:
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
MaterializationKind kind = MaterializationKind::Target,
Type origArgType = nullptr)
Type origOutputType = nullptr)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), origArgType(origArgType) {
assert(kind == MaterializationKind::Argument ||
!origArgType && "orginal argument type make sense only for argument "
"materializations");
}
converterAndKind(converter, kind), origOutputType(origOutputType) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@ -718,17 +734,17 @@ public:
return converterAndKind.getInt();
}
/// Return the original type of the block argument.
Type getOrigArgType() const { return origArgType; }
/// Return the original illegal output type of the input values.
Type getOrigOutputType() const { return origOutputType; }
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
converterAndKind;
/// The original output type. This is only used for argument conversions.
Type origArgType;
Type origOutputType;
};
} // namespace
@ -846,6 +862,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);
Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
Type origOutputType,
Type outputType,
const TypeConverter *converter);
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);
@ -975,6 +998,28 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
for (Operation *op : block->getUsers())
listener->notifyOperationModified(op);
// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
// Handle the case of a 1->0 value mapping.
if (!info) {
if (Value newArg =
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
rewriter.replaceAllUsesWith(origArg, newArg);
continue;
}
// Otherwise this is a 1->1+ value mapping.
Value castValue = info->castValue;
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty()) {
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
castValue, origArg.getType()));
}
}
}
void BlockTypeConversionRewrite::rollback() {
@ -998,13 +1043,15 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
if (!liveUser)
continue;
Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
assert(replacementValue && "replacement value not found");
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
bool isDroppedArg = replacementValue == origArg;
if (!isDroppedArg)
builder.setInsertionPointAfterValue(replacementValue);
Value newArg;
if (converter) {
builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
builder, origArg.getLoc(), origArg.getType(), replacementValue);
builder, origArg.getLoc(), origArg.getType(),
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
@ -1015,6 +1062,8 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
if (!isDroppedArg)
diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
@ -1300,65 +1349,65 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);
for (unsigned i = 0; i != origArgCount; ++i) {
BlockArgument origArg = block->getArgument(i);
Type origArgType = origArg.getType();
// Remap each of the original arguments as determined by the signature
// conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
argInfo.resize(origArgCount);
// Helper function that tries to legalize the given type. Returns the given
// type if it could not be legalized.
for (unsigned i = 0; i != origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap)
continue;
BlockArgument origArg = block->getArgument(i);
// If inputMap->replacementValue is not nullptr, then the argument is
// dropped and a replacement value is provided to be the remappedValue.
if (inputMap->replacementValue) {
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
// Otherwise, this is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value newArg;
// If this is a 1->1 mapping and the types of new and replacement arguments
// match (i.e. it's an identity map), then the argument is mapped to its
// original type.
// FIXME: We simply pass through the replacement argument if there wasn't a
// converter, which isn't great as it allows implicit type conversions to
// appear. We should properly restructure this code to handle cases where a
// converter isn't provided and also to properly handle the case where an
// argument materialization is actually a temporary source materialization
// (e.g. in the case of 1->N).
auto tryLegalizeType = [&](Type type) {
if (converter)
if (Type t = converter->convertType(type))
return t;
return type;
};
if (replArgs.size() == 1 &&
(!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
} else {
Type origOutputType = origArg.getType();
std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
signatureConversion.getInputMapping(i);
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Source, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*origArgType=*/{}, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
// Legalize the argument output type.
Type outputType = origOutputType;
if (Type legalOutputType = converter->convertType(outputType))
outputType = legalOutputType;
newArg = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
converter);
}
if (Value repl = inputMap->replacementValue) {
// This block argument was dropped and a replacement value was provided.
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
// dialect conversion. Therefore, we need an argument materialization to
// turn the replacement block arguments into a single SSA value that can be
// used as a replacement. The type of this SSA value is the legalized
// version of the original block argument type.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Argument, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/replArgs,
/*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
mapping.map(origArg, repl);
mapping.map(origArg, newArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
converter);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@ -1375,7 +1424,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
Location loc, ValueRange inputs, Type outputType, Type origArgType,
Location loc, ValueRange inputs, Type outputType, Type origOutputType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@ -1387,9 +1436,16 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
origArgType);
origOutputType);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
Block *block, Location loc, ValueRange inputs, Type origOutputType,
Type outputType, const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
origOutputType, converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
@ -1398,9 +1454,9 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
if (OpResult inputRes = dyn_cast<OpResult>(input))
insertPt = ++inputRes.getOwner()->getIterator();
return buildUnresolvedMaterialization(
MaterializationKind::Target, insertBlock, insertPt, loc, input,
outputType, /*origArgType=*/{}, converter);
return buildUnresolvedMaterialization(MaterializationKind::Target,
insertBlock, insertPt, loc, input,
outputType, outputType, converter);
}
//===----------------------------------------------------------------------===//
@ -2796,7 +2852,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
// easily misunderstood. We should clean up the argument hooks to better
// represent the desired invariants we actually care about.
newMaterialization = converter->materializeArgumentConversion(
rewriter, op->getLoc(), mat.getOrigArgType(), inputOperands);
rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
if (newMaterialization)
break;
@ -2807,10 +2863,6 @@ static LogicalResult legalizeUnresolvedMaterialization(
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
case MaterializationKind::Source:
newMaterialization = converter->materializeSourceConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
}
if (newMaterialization) {
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
@ -2821,8 +2873,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
"from ("
<< inputOperands.getTypes() << ") to " << outputType
"from "
<< inputOperands.getTypes() << " to " << outputType
<< " that remained live after conversion";
if (Operation *liveUser = findLiveUser(op->getUsers())) {
diag.attachNote(liveUser->getLoc())

View File

@ -2,8 +2,9 @@
func.func @test_invalid_arg_materialization(
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
%arg0: i16) {
// expected-note@below {{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}
@ -103,8 +104,9 @@ func.func @test_block_argument_not_converted() {
// Make sure argument type changes aren't implicitly forwarded.
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
^bb0(%arg0: f32):
// expected-note@below {{see existing live user here}}
"test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()