llvm-project/clang/lib/Sema/SemaSPIRV.cpp
Kaitlin Peng 40c1d50024
[HLSL] Implement the smoothstep intrinsic (#132288)
Closes #99156.


Tasks completed:
- Implement `smoothstep` using HLSL source in `hlsl_intrinsics.h`
- Implement the `smoothstep` SPIR-V target built-in in
`clang/include/clang/Basic/BuiltinsSPIRV.td`
- Add sema checks for `smoothstep` to `CheckSPIRVBuiltinFunctionCall` in
`clang/lib/Sema/SemaSPIRV.cpp`
- Add codegen for spv `smoothstep` to `EmitSPIRVBuiltinExpr` in
`clang/lib/CodeGen/TargetBuiltins/SPIR.cpp`
- Add codegen tests to `clang/test/CodeGenHLSL/builtins/smoothstep.hlsl`
- Add spv codegen test to
`clang/test/CodeGenSPIRV/Builtins/smoothstep.c`
- Add sema tests to
`clang/test/SemaHLSL/BuiltIns/smoothstep-errors.hlsl`
- Add spv sema tests to
`clang/test/SemaSPIRV/BuiltIns/smoothstep-errors.c`
- Create the `int_spv_smoothstep` intrinsic in `IntrinsicsSPIRV.td`
- In SPIRVInstructionSelector.cpp create the `smoothstep` lowering and
map it to `int_spv_smoothstep` in
`SPIRVInstructionSelector::selectIntrinsic`
- Create SPIR-V backend test case in
`llvm/test/CodeGen/SPIRV/hlsl-intrinsics/smoothstep.ll`
- Create SPIR-V backend test case in
`llvm/test/CodeGen/SPIRV/opencl/smoothstep.ll`
2025-03-29 00:47:25 -04:00

144 lines
4.8 KiB
C++

//===- SemaSPIRV.cpp - Semantic Analysis for SPIRV constructs--------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This implements Semantic Analysis for SPIRV constructs.
//===----------------------------------------------------------------------===//
#include "clang/Sema/SemaSPIRV.h"
#include "clang/Basic/TargetBuiltins.h"
#include "clang/Sema/Sema.h"
namespace clang {
SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
switch (BuiltinID) {
case SPIRV::BI__builtin_spirv_distance: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
ExprResult B = TheCall->getArg(1);
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
QualType RetTy = VTyA->getElementType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_length: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTy = ArgTyA->getAs<VectorType>();
if (VTy == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
QualType RetTy = VTy->getElementType();
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
ExprResult B = TheCall->getArg(1);
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
QualType RetTy = ArgTyA;
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_smoothstep: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
// check if the all arguments have floating representation
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
ExprResult Arg = TheCall->getArg(i);
QualType ArgTy = Arg.get()->getType();
if (!ArgTy->hasFloatingRepresentation()) {
SemaRef.Diag(Arg.get()->getBeginLoc(),
diag::err_builtin_invalid_arg_type)
<< i + 1 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
<< ArgTy;
return true;
}
}
// check if all arguments are of the same type
ExprResult A = TheCall->getArg(0);
ExprResult B = TheCall->getArg(1);
ExprResult C = TheCall->getArg(2);
if (!(SemaRef.getASTContext().hasSameUnqualifiedType(A.get()->getType(),
B.get()->getType()) &&
SemaRef.getASTContext().hasSameUnqualifiedType(A.get()->getType(),
C.get()->getType()))) {
SemaRef.Diag(TheCall->getBeginLoc(),
diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee() << /*useAllTerminology*/ true
<< SourceRange(A.get()->getBeginLoc(), C.get()->getEndLoc());
return true;
}
QualType RetTy = A.get()->getType();
TheCall->setType(RetTy);
break;
}
}
return false;
}
} // namespace clang