llvm-project/clang/lib/Serialization/TemplateArgumentHasher.cpp
Chuanqi Xu 20e9049509
[Serialization] Support loading template specializations lazily (#119333)
Reland https://github.com/llvm/llvm-project/pull/83237

---

(Original comments)

Currently all the specializations of a template (including
instantiations, explicit specializations and partial specializations)
will be loaded at once if we want to instantiate another instance of the
template, find an instantiation of the template, or just complete the
redecl chain.

This basically means that we need to load every specialization of the
template once the template declaration gets loaded. This is bad because
loading a specialization requires loading all of its template
arguments, so we end up deserializing a lot of unnecessary
declarations.

For example,

```
// M.cppm
export module M;
export template <class T>
class A {};

export class ShouldNotBeLoaded {};

export class Temp {
   A<ShouldNotBeLoaded> AS;
};

// use.cpp
import M;
A<int> a;
```

We have a specialization `A<ShouldNotBeLoaded>` in `M.cppm` and we
instantiate the template `A` in `use.cpp`. Surprisingly, we will then
deserialize `ShouldNotBeLoaded` when compiling `use.cpp`. This patch
tries to avoid that.

Given that templates are heavily used in C++, this is a performance
pain point.

This patch adds a MultiOnDiskHashTable for specializations in the
ASTReader. Then we will only deserialize the specializations with the
same template arguments. We achieve that by using the ODRHash of the
template arguments as the key of the hash table.

To review this patch, I think `ASTReaderDecl::AddLazySpecializations`
may be a good entry point.
2024-12-11 09:40:47 +08:00

410 lines
12 KiB
C++

//===- TemplateArgumentHasher.cpp - Hash Template Arguments -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TemplateArgumentHasher.h"
#include "clang/AST/APValue.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/DeclarationName.h"
#include "clang/AST/TypeVisitor.h"
#include "clang/Basic/IdentifierTable.h"
#include "llvm/ADT/FoldingSet.h"
using namespace clang;
namespace {
/// Accumulates template arguments into a FoldingSetNodeID and turns them into
/// a hash value.
///
/// The hasher is only required to give the same hash value for the same
/// template arguments; it is not required to give different hash values for
/// different template arguments. Hashing may therefore be abandoned ("bail
/// out") at any point and for any reason, in which case the fixed
/// BailedOutValue is reported. In the worst case it would still be a valid
/// implementation to map every input to BailedOutValue.
class TemplateArgumentHasher {
  /// Collects everything that contributes to the hash.
  llvm::FoldingSetNodeID ID;

  /// The value reported by getValue() once hashing has been abandoned.
  static constexpr unsigned BailedOutValue = 0x12345678;

  /// True once hashing has been abandoned.
  bool BailedOut = false;

public:
  TemplateArgumentHasher() = default;

  /// Returns BailedOutValue if hashing was abandoned, otherwise the stable
  /// hash computed from everything added so far.
  unsigned getValue() {
    return BailedOut ? BailedOutValue : ID.computeStableHash();
  }

  /// Abandons hashing; getValue() will report BailedOutValue from now on.
  void setBailedOut() { BailedOut = true; }

  /// Adds a raw integer to the hash.
  void AddInteger(unsigned V) { ID.AddInteger(V); }

  // Entry points for the individual entities that may take part in a
  // template argument.
  void AddTemplateArgument(TemplateArgument TA);
  void AddType(const Type *T);
  void AddQualType(QualType T);
  void AddDecl(const Decl *D);
  void AddStructuralValue(const APValue &);
  void AddTemplateName(TemplateName Name);
  void AddDeclarationName(DeclarationName Name);
  void AddIdentifierInfo(const IdentifierInfo *II);
};
/// Feeds a single template argument into the hash.
///
/// The argument kind is always hashed first, followed by the data specific to
/// that kind. Kinds for which no stable hash can be produced make the whole
/// computation bail out.
void TemplateArgumentHasher::AddTemplateArgument(TemplateArgument TA) {
  const auto Kind = TA.getKind();
  AddInteger(Kind);

  switch (Kind) {
  case TemplateArgument::Null:
    // A null argument has no payload beyond its kind, which has already been
    // hashed above. Tolerate it instead of declaring it unreachable: the
    // hasher is allowed to produce a weak hash, whereas llvm_unreachable is
    // undefined behavior in builds without assertions should a null argument
    // ever get here (e.g. from an incomplete substitution).
    break;
  case TemplateArgument::Type:
    AddQualType(TA.getAsType());
    break;
  case TemplateArgument::Declaration:
    AddDecl(TA.getAsDecl());
    break;
  case TemplateArgument::NullPtr:
    ID.AddPointer(nullptr);
    break;
  case TemplateArgument::Integral: {
    // There are integrals (e.g.: _BitInt(128)) that cannot be represented as
    // any builtin integral type, so we use the hash of APSInt instead.
    TA.getAsIntegral().Profile(ID);
    break;
  }
  case TemplateArgument::StructuralValue:
    AddQualType(TA.getStructuralValueType());
    AddStructuralValue(TA.getAsStructuralValue());
    break;
  case TemplateArgument::Template:
  case TemplateArgument::TemplateExpansion:
    AddTemplateName(TA.getAsTemplateOrTemplatePattern());
    break;
  case TemplateArgument::Expression:
    // If we meet an expression in a template argument, it implies that the
    // template is still dependent. It is meaningless to get a stable hash
    // for the template. Bail out simply.
    BailedOut = true;
    break;
  case TemplateArgument::Pack:
    // Hash the number of elements first, then every element in order.
    AddInteger(TA.pack_size());
    for (const TemplateArgument &SubTA : TA.pack_elements())
      AddTemplateArgument(SubTA);
    break;
  }
}
void TemplateArgumentHasher::AddStructuralValue(const APValue &Value) {
auto Kind = Value.getKind();
AddInteger(Kind);
// 'APValue::Profile' uses pointer values to make hash for LValue and
// MemberPointer, but they differ from one compiler invocation to another.
// It may be difficult to handle such cases. Bail out simply.
if (Kind == APValue::LValue || Kind == APValue::MemberPointer) {
BailedOut = true;
return;
}
Value.Profile(ID);
}
/// Hashes a template name.
///
/// Plain template names are hashed through their template declaration.
/// Qualified and deduced names are looked through to the underlying name.
/// A using-template is hashed through the target of its shadow declaration
/// when there is one. All remaining kinds abandon hashing.
void TemplateArgumentHasher::AddTemplateName(TemplateName Name) {
  switch (Name.getKind()) {
  case TemplateName::Template:
    AddDecl(Name.getAsTemplateDecl());
    return;

  case TemplateName::QualifiedTemplate: {
    QualifiedTemplateName *Qualified = Name.getAsQualifiedTemplateName();
    AddTemplateName(Qualified->getUnderlyingTemplate());
    return;
  }

  case TemplateName::DeducedTemplate:
    AddTemplateName(Name.getAsDeducedTemplateName()->getUnderlying());
    return;

  case TemplateName::UsingTemplate:
    if (UsingShadowDecl *Shadow = Name.getAsUsingShadowDecl()) {
      AddDecl(Shadow->getTargetDecl());
      return;
    }
    // No shadow declaration to hash.
    setBailedOut();
    return;

  case TemplateName::OverloadedTemplate:
  case TemplateName::AssumedTemplate:
  case TemplateName::DependentTemplate:
  case TemplateName::SubstTemplateTemplateParm:
  case TemplateName::SubstTemplateTemplateParmPack:
    setBailedOut();
    return;
  }
}
/// Adds the name of an identifier to the hash. \p II must not be null.
void TemplateArgumentHasher::AddIdentifierInfo(const IdentifierInfo *II) {
  assert(II && "Expecting non-null pointer.");
  const auto Spelling = II->getName();
  ID.AddString(Spelling);
}
/// Hashes a declaration name according to its kind.
///
/// An empty name adds nothing to the hash. Objective-C selector names
/// abandon hashing.
void TemplateArgumentHasher::AddDeclarationName(DeclarationName Name) {
  if (Name.isEmpty())
    return;

  switch (Name.getNameKind()) {
  // Names that are spelled with an identifier.
  case DeclarationName::Identifier:
    AddIdentifierInfo(Name.getAsIdentifierInfo());
    break;
  case DeclarationName::CXXLiteralOperatorName:
    AddIdentifierInfo(Name.getCXXLiteralIdentifier());
    break;

  // Names that are determined by a type.
  case DeclarationName::CXXConstructorName:
  case DeclarationName::CXXDestructorName:
  case DeclarationName::CXXConversionFunctionName:
    AddQualType(Name.getCXXNameType());
    break;

  case DeclarationName::CXXOperatorName:
    AddInteger(Name.getCXXOverloadedOperator());
    break;

  case DeclarationName::CXXDeductionGuideName:
    // Hash the template the guide belongs to, if there is one.
    if (auto *Template = Name.getCXXDeductionGuideTemplate())
      AddDecl(Template);
    break;

  case DeclarationName::CXXUsingDirective:
    // Nothing to add.
    break;

  case DeclarationName::ObjCZeroArgSelector:
  case DeclarationName::ObjCOneArgSelector:
  case DeclarationName::ObjCMultiArgSelector:
    setBailedOut();
    break;
  }
}
void TemplateArgumentHasher::AddDecl(const Decl *D) {
const NamedDecl *ND = dyn_cast<NamedDecl>(D);
if (!ND) {
BailedOut = true;
return;
}
AddDeclarationName(ND->getDeclName());
}
/// Hashes a qualified type: first its qualifiers (as an opaque integer), then
/// the unqualified type. A null QualType abandons hashing.
void TemplateArgumentHasher::AddQualType(QualType T) {
  if (T.isNull()) {
    setBailedOut();
    return;
  }

  const SplitQualType Parts = T.split();
  AddInteger(Parts.Quals.getAsOpaqueValue());
  AddType(Parts.Ty);
}
// Process a Type pointer. Add* methods call back into TemplateArgumentHasher
// while Visit* methods process the relevant parts of the Type.
// Any unhandled type will make the hash computation bail out.
//
// NOTE: the resulting hash depends on the exact order in which the Visit*
// methods add data, so statements here must not be reordered casually.
class TypeVisitorHelper : public TypeVisitor<TypeVisitorHelper> {
typedef TypeVisitor<TypeVisitorHelper> Inherited;
// The node ID handed in by the creator of this helper. It is used directly
// by the Profile() calls below.
llvm::FoldingSetNodeID &ID;
// The hasher that receives all Add* callbacks and the bail-out request.
TemplateArgumentHasher &Hash;
public:
TypeVisitorHelper(llvm::FoldingSetNodeID &ID, TemplateArgumentHasher &Hash)
: ID(ID), Hash(Hash) {}
// Hashes a declaration; a null declaration is recorded as the integer 0.
void AddDecl(const Decl *D) {
if (D)
Hash.AddDecl(D);
else
Hash.AddInteger(0);
}
// Forwards a qualified type to the hasher.
void AddQualType(QualType T) { Hash.AddQualType(T); }
// Hashes a type; a null type is recorded as the integer 0.
void AddType(const Type *T) {
if (T)
Hash.AddType(T);
else
Hash.AddInteger(0);
}
// Hashes a set of qualifiers through its opaque integer value.
void VisitQualifiers(Qualifiers Quals) {
Hash.AddInteger(Quals.getAsOpaqueValue());
}
// Entry point: forwards to the Visit implementation of the TypeVisitor base.
void Visit(const Type *T) { Inherited::Visit(T); }
// Unhandled types. Bail out simply.
void VisitType(const Type *T) { Hash.setBailedOut(); }
// Adjusted types are hashed through their original type.
void VisitAdjustedType(const AdjustedType *T) {
AddQualType(T->getOriginalType());
}
void VisitDecayedType(const DecayedType *T) {
// getDecayedType and getPointeeType are derived from getAdjustedType
// and don't need to be separately processed.
VisitAdjustedType(T);
}
// Data shared by all array types: element type, size modifier and index
// type qualifiers.
void VisitArrayType(const ArrayType *T) {
AddQualType(T->getElementType());
Hash.AddInteger(llvm::to_underlying(T->getSizeModifier()));
VisitQualifiers(T->getIndexTypeQualifiers());
}
// Constant arrays profile their size first, then add the shared array data.
void VisitConstantArrayType(const ConstantArrayType *T) {
T->getSize().Profile(ID);
VisitArrayType(T);
}
// Attributed types contribute the attribute kind and the modified type.
void VisitAttributedType(const AttributedType *T) {
Hash.AddInteger(T->getAttrKind());
AddQualType(T->getModifiedType());
}
// Builtin types are identified by their kind alone.
void VisitBuiltinType(const BuiltinType *T) { Hash.AddInteger(T->getKind()); }
// Complex types are hashed through their element type.
void VisitComplexType(const ComplexType *T) {
AddQualType(T->getElementType());
}
// decltype types are hashed through their underlying type.
void VisitDecltypeType(const DecltypeType *T) {
AddQualType(T->getUnderlyingType());
}
// Deduced types are hashed through the deduced type. AddQualType bails out
// if that type is null.
void VisitDeducedType(const DeducedType *T) {
AddQualType(T->getDeducedType());
}
void VisitAutoType(const AutoType *T) { VisitDeducedType(T); }
// Adds the template name first, then the deduced type.
void VisitDeducedTemplateSpecializationType(
const DeducedTemplateSpecializationType *T) {
Hash.AddTemplateName(T->getTemplateName());
VisitDeducedType(T);
}
// Data shared by all function types: return type, the profiled ExtInfo and
// the const/volatile/restrict flags.
void VisitFunctionType(const FunctionType *T) {
AddQualType(T->getReturnType());
T->getExtInfo().Profile(ID);
Hash.AddInteger(T->isConst());
Hash.AddInteger(T->isVolatile());
Hash.AddInteger(T->isRestrict());
}
void VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
VisitFunctionType(T);
}
// Prototyped functions add the parameter count and every parameter type
// before the shared function data.
void VisitFunctionProtoType(const FunctionProtoType *T) {
Hash.AddInteger(T->getNumParams());
for (auto ParamType : T->getParamTypes())
AddQualType(ParamType);
VisitFunctionType(T);
}
// Member pointers contribute the pointee type and the class type.
void VisitMemberPointerType(const MemberPointerType *T) {
AddQualType(T->getPointeeType());
AddType(T->getClass());
}
// Pack expansions are hashed through their pattern.
void VisitPackExpansionType(const PackExpansionType *T) {
AddQualType(T->getPattern());
}
// Parenthesized types are hashed through their inner type.
void VisitParenType(const ParenType *T) { AddQualType(T->getInnerType()); }
// Pointers are hashed through their pointee type.
void VisitPointerType(const PointerType *T) {
AddQualType(T->getPointeeType());
}
// References are hashed through the pointee type as written.
void VisitReferenceType(const ReferenceType *T) {
AddQualType(T->getPointeeTypeAsWritten());
}
void VisitLValueReferenceType(const LValueReferenceType *T) {
VisitReferenceType(T);
}
void VisitRValueReferenceType(const RValueReferenceType *T) {
VisitReferenceType(T);
}
// Substituted parameter packs contribute the associated declaration and the
// argument pack.
void
VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) {
AddDecl(T->getAssociatedDecl());
Hash.AddTemplateArgument(T->getArgumentPack());
}
// Substituted type parameters contribute the associated declaration and the
// replacement type.
void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T) {
AddDecl(T->getAssociatedDecl());
AddQualType(T->getReplacementType());
}
// Tag types (records and enums) are hashed through their declaration.
void VisitTagType(const TagType *T) { AddDecl(T->getDecl()); }
void VisitRecordType(const RecordType *T) { VisitTagType(T); }
void VisitEnumType(const EnumType *T) { VisitTagType(T); }
// Template specializations add the argument count, every argument in order,
// and finally the template name.
void VisitTemplateSpecializationType(const TemplateSpecializationType *T) {
Hash.AddInteger(T->template_arguments().size());
for (const auto &TA : T->template_arguments()) {
Hash.AddTemplateArgument(TA);
}
Hash.AddTemplateName(T->getTemplateName());
}
// Template type parameters are identified by depth, index and whether they
// are a parameter pack.
void VisitTemplateTypeParmType(const TemplateTypeParmType *T) {
Hash.AddInteger(T->getDepth());
Hash.AddInteger(T->getIndex());
Hash.AddInteger(T->isParameterPack());
}
// Typedef types are hashed through the typedef declaration.
void VisitTypedefType(const TypedefType *T) { AddDecl(T->getDecl()); }
// Elaborated types are hashed through the named type.
void VisitElaboratedType(const ElaboratedType *T) {
AddQualType(T->getNamedType());
}
// Unary transform types contribute the underlying type and the base type.
void VisitUnaryTransformType(const UnaryTransformType *T) {
AddQualType(T->getUnderlyingType());
AddQualType(T->getBaseType());
}
// Vector types contribute the element type, element count and vector kind.
void VisitVectorType(const VectorType *T) {
AddQualType(T->getElementType());
Hash.AddInteger(T->getNumElements());
Hash.AddInteger(llvm::to_underlying(T->getVectorKind()));
}
void VisitExtVectorType(const ExtVectorType *T) { VisitVectorType(T); }
};
void TemplateArgumentHasher::AddType(const Type *T) {
assert(T && "Expecting non-null pointer.");
TypeVisitorHelper(ID, *this).Visit(T);
}
} // namespace
/// Computes the hash value for a list of template arguments.
///
/// The number of arguments is hashed first, followed by each argument in
/// order. The result is the hasher's bailed-out value if any argument could
/// not be hashed.
unsigned clang::serialization::StableHashForTemplateArguments(
    llvm::ArrayRef<TemplateArgument> Args) {
  TemplateArgumentHasher ArgsHasher;

  ArgsHasher.AddInteger(Args.size());
  for (const TemplateArgument &Current : Args)
    ArgsHasher.AddTemplateArgument(Current);

  return ArgsHasher.getValue();
}