//===- SemaHLSL.cpp - Semantic Analysis for HLSL constructs ---------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This implements Semantic Analysis for HLSL constructs. //===----------------------------------------------------------------------===// #include "clang/Sema/SemaHLSL.h" #include "clang/Basic/DiagnosticSema.h" #include "clang/Basic/LLVM.h" #include "clang/Basic/TargetInfo.h" #include "clang/Sema/Sema.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/TargetParser/Triple.h" #include using namespace clang; SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc, IdentifierInfo *Ident, SourceLocation IdentLoc, SourceLocation LBrace) { // For anonymous namespace, take the location of the left brace. DeclContext *LexicalParent = SemaRef.getCurLexicalContext(); HLSLBufferDecl *Result = HLSLBufferDecl::Create( getASTContext(), LexicalParent, CBuffer, KwLoc, Ident, IdentLoc, LBrace); SemaRef.PushOnScopeChains(Result, BufferScope); SemaRef.PushDeclContext(BufferScope, Result); return Result; } void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { auto *BufDecl = cast(Dcl); BufDecl->setRBraceLoc(RBrace); SemaRef.PopDeclContext(); } HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z) { if (HLSLNumThreadsAttr *NT = D->getAttr()) { if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; Diag(AL.getLoc(), diag::note_conflicting_attribute); } return nullptr; } return ::new (getASTContext()) HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); } HLSLShaderAttr * SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, HLSLShaderAttr::ShaderType ShaderType) { if (HLSLShaderAttr *NT = D->getAttr()) { if (NT->getType() != ShaderType) { Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; Diag(AL.getLoc(), diag::note_conflicting_attribute); } return nullptr; } return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL); } HLSLParamModifierAttr * SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling) { // We can only merge an `in` attribute with an `out` attribute. All other // combinations of duplicated attributes are ill-formed. if (HLSLParamModifierAttr *PA = D->getAttr()) { if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { D->dropAttr(); SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; return HLSLParamModifierAttr::Create( getASTContext(), /*MergedSpelling=*/true, AdjustedRange, HLSLParamModifierAttr::Keyword_inout); } Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL; Diag(PA->getLocation(), diag::note_conflicting_attribute); return nullptr; } return HLSLParamModifierAttr::Create(getASTContext(), AL); } void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { auto &TargetInfo = getASTContext().getTargetInfo(); if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) return; StringRef Env = TargetInfo.getTriple().getEnvironmentName(); HLSLShaderAttr::ShaderType ShaderType; if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) { if (const auto *Shader = FD->getAttr()) { // The entry point is already annotated - check that it matches the // triple. if (Shader->getType() != ShaderType) { Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) << Shader; FD->setInvalidDecl(); } } else { // Implicitly add the shader attribute if the entry function isn't // explicitly annotated. FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType, FD->getBeginLoc())); } } else { switch (TargetInfo.getTriple().getEnvironment()) { case llvm::Triple::UnknownEnvironment: case llvm::Triple::Library: break; default: llvm_unreachable("Unhandled environment in triple"); } } } void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { const auto *ShaderAttr = FD->getAttr(); assert(ShaderAttr && "Entry point has no shader attribute"); HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); switch (ST) { case HLSLShaderAttr::Pixel: case HLSLShaderAttr::Vertex: case HLSLShaderAttr::Geometry: case HLSLShaderAttr::Hull: case HLSLShaderAttr::Domain: case HLSLShaderAttr::RayGeneration: case HLSLShaderAttr::Intersection: case HLSLShaderAttr::AnyHit: case HLSLShaderAttr::ClosestHit: case HLSLShaderAttr::Miss: case HLSLShaderAttr::Callable: if (const auto *NT = FD->getAttr()) { DiagnoseAttrStageMismatch(NT, ST, {HLSLShaderAttr::Compute, HLSLShaderAttr::Amplification, HLSLShaderAttr::Mesh}); FD->setInvalidDecl(); } break; case HLSLShaderAttr::Compute: case HLSLShaderAttr::Amplification: case HLSLShaderAttr::Mesh: if (!FD->hasAttr()) { Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) << HLSLShaderAttr::ConvertShaderTypeToStr(ST); FD->setInvalidDecl(); } break; } for (ParmVarDecl *Param : FD->parameters()) { if (const auto *AnnotationAttr = Param->getAttr()) { CheckSemanticAnnotation(FD, Param, AnnotationAttr); } else { // FIXME: Handle struct parameters where annotations are on struct fields. // See: https://github.com/llvm/llvm-project/issues/57875 Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation); Diag(Param->getLocation(), diag::note_previous_decl) << Param; FD->setInvalidDecl(); } } // FIXME: Verify return type semantic annotation. } void SemaHLSL::CheckSemanticAnnotation( FunctionDecl *EntryPoint, const Decl *Param, const HLSLAnnotationAttr *AnnotationAttr) { auto *ShaderAttr = EntryPoint->getAttr(); assert(ShaderAttr && "Entry point has no shader attribute"); HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); switch (AnnotationAttr->getKind()) { case attr::HLSLSV_DispatchThreadID: case attr::HLSLSV_GroupIndex: if (ST == HLSLShaderAttr::Compute) return; DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute}); break; default: llvm_unreachable("Unknown HLSLAnnotationAttr"); } } void SemaHLSL::DiagnoseAttrStageMismatch( const Attr *A, HLSLShaderAttr::ShaderType Stage, std::initializer_list AllowedStages) { SmallVector StageStrings; llvm::transform(AllowedStages, std::back_inserter(StageStrings), [](HLSLShaderAttr::ShaderType ST) { return StringRef( HLSLShaderAttr::ConvertShaderTypeToStr(ST)); }); Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage) << (AllowedStages.size() != 1) << join(StageStrings, ", "); }