mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-17 18:16:42 +00:00
[flang] AArch64 ABI for BIND(C) VALUE parameters (#118305)
This patch adds handling for derived type VALUE parameters in BIND(C) functions for AArch64.
This commit is contained in:
parent
3666de9c8e
commit
44aa476aa1
@ -788,6 +788,8 @@ struct TargetX86_64Win : public GenericTarget<TargetX86_64Win> {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// AArch64 procedure call standard:
|
||||
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
|
||||
struct TargetAArch64 : public GenericTarget<TargetAArch64> {
|
||||
using GenericTarget::GenericTarget;
|
||||
|
||||
@ -826,7 +828,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
|
||||
return marshal;
|
||||
}
|
||||
|
||||
// Flatten a RecordType::TypeList containing more record types or array types
|
||||
// Flatten a RecordType::TypeList containing more record types or array type
|
||||
static std::optional<std::vector<mlir::Type>>
|
||||
flattenTypeList(const RecordType::TypeList &types) {
|
||||
std::vector<mlir::Type> flatTypes;
|
||||
@ -870,52 +872,144 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
|
||||
|
||||
// Determine if the type is a Homogenous Floating-point Aggregate (HFA). An
|
||||
// HFA is a record type with up to 4 floating-point members of the same type.
|
||||
static bool isHFA(fir::RecordType ty) {
|
||||
static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
|
||||
RecordType::TypeList types = ty.getTypeList();
|
||||
if (types.empty() || types.size() > 4)
|
||||
return false;
|
||||
return std::nullopt;
|
||||
|
||||
std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
|
||||
if (!flatTypes || flatTypes->size() > 4) {
|
||||
return false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!isa_real(flatTypes->front())) {
|
||||
return false;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return llvm::all_equal(*flatTypes);
|
||||
return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
|
||||
: std::nullopt;
|
||||
}
|
||||
|
||||
struct NRegs {
|
||||
int n{0};
|
||||
bool isSimd{false};
|
||||
};
|
||||
|
||||
NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
|
||||
if (std::optional<int> size = usedRegsForHFA(type))
|
||||
return {*size, true};
|
||||
|
||||
auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
|
||||
loc, type, getDataLayout(), kindMap);
|
||||
|
||||
if (size <= 16)
|
||||
return {static_cast<int>((size + 7) / 8), false};
|
||||
|
||||
// Pass on the stack, i.e. no registers used
|
||||
return {};
|
||||
}
|
||||
|
||||
NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
|
||||
return llvm::TypeSwitch<mlir::Type, NRegs>(type)
|
||||
.Case<mlir::IntegerType>([&](auto intTy) {
|
||||
return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
|
||||
})
|
||||
.Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
|
||||
.Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
|
||||
.Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
|
||||
.Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
|
||||
.Case<fir::SequenceType>([&](auto ty) {
|
||||
assert(ty.getShape().size() == 1 &&
|
||||
"invalid array dimensions in BIND(C)");
|
||||
NRegs nregs = usedRegsForType(loc, ty.getEleTy());
|
||||
nregs.n *= ty.getShape()[0];
|
||||
return nregs;
|
||||
})
|
||||
.Case<fir::RecordType>(
|
||||
[&](auto ty) { return usedRegsForRecordType(loc, ty); })
|
||||
.Case<fir::VectorType>([&](auto) {
|
||||
TODO(loc, "passing vector argument to C by value is not supported");
|
||||
return NRegs{};
|
||||
});
|
||||
}
|
||||
|
||||
bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
|
||||
const Marshalling &previousArguments) const {
|
||||
int availIntRegisters = 8;
|
||||
int availSIMDRegisters = 8;
|
||||
|
||||
// Check previous arguments to see how many registers are used already
|
||||
for (auto [type, attr] : previousArguments) {
|
||||
if (availIntRegisters <= 0 || availSIMDRegisters <= 0)
|
||||
break;
|
||||
|
||||
if (attr.isByVal())
|
||||
continue; // Previous argument passed on the stack
|
||||
|
||||
NRegs nregs = usedRegsForType(loc, type);
|
||||
if (nregs.isSimd)
|
||||
availSIMDRegisters -= nregs.n;
|
||||
else
|
||||
availIntRegisters -= nregs.n;
|
||||
}
|
||||
|
||||
NRegs nregs = usedRegsForRecordType(loc, type);
|
||||
|
||||
if (nregs.isSimd)
|
||||
return nregs.n <= availSIMDRegisters;
|
||||
|
||||
return nregs.n <= availIntRegisters;
|
||||
}
|
||||
|
||||
// AArch64 procedure call ABI:
|
||||
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
|
||||
CodeGenSpecifics::Marshalling
|
||||
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
|
||||
passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
|
||||
CodeGenSpecifics::Marshalling marshal;
|
||||
auto sizeAndAlign =
|
||||
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
|
||||
// The stack is always 8 byte aligned
|
||||
unsigned short align =
|
||||
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
|
||||
marshal.emplace_back(fir::ReferenceType::get(ty),
|
||||
AT{align, /*byval=*/!isResult, /*sret=*/isResult});
|
||||
return marshal;
|
||||
}
|
||||
|
||||
CodeGenSpecifics::Marshalling
|
||||
structType(mlir::Location loc, fir::RecordType type, bool isResult) const {
|
||||
NRegs nregs = usedRegsForRecordType(loc, type);
|
||||
|
||||
// If the type needs no registers it must need to be passed on the stack
|
||||
if (nregs.n == 0)
|
||||
return passOnTheStack(loc, type, isResult);
|
||||
|
||||
CodeGenSpecifics::Marshalling marshal;
|
||||
|
||||
if (isHFA(ty)) {
|
||||
// Just return the existing record type
|
||||
marshal.emplace_back(ty, AT{});
|
||||
return marshal;
|
||||
mlir::Type pcsType;
|
||||
if (nregs.isSimd) {
|
||||
pcsType = type;
|
||||
} else {
|
||||
pcsType = fir::SequenceType::get(
|
||||
nregs.n, mlir::IntegerType::get(type.getContext(), 64));
|
||||
}
|
||||
|
||||
auto [size, align] =
|
||||
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
|
||||
|
||||
// return in registers if size <= 16 bytes
|
||||
if (size <= 16) {
|
||||
std::size_t dwordSize = (size + 7) / 8;
|
||||
auto newTy = fir::SequenceType::get(
|
||||
dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
|
||||
marshal.emplace_back(newTy, AT{});
|
||||
return marshal;
|
||||
}
|
||||
|
||||
unsigned short stackAlign = std::max<unsigned short>(align, 8u);
|
||||
marshal.emplace_back(fir::ReferenceType::get(ty),
|
||||
AT{stackAlign, false, true});
|
||||
marshal.emplace_back(pcsType, AT{});
|
||||
return marshal;
|
||||
}
|
||||
|
||||
CodeGenSpecifics::Marshalling
|
||||
structArgumentType(mlir::Location loc, fir::RecordType ty,
|
||||
const Marshalling &previousArguments) const override {
|
||||
if (!hasEnoughRegisters(loc, ty, previousArguments)) {
|
||||
return passOnTheStack(loc, ty, /*isResult=*/false);
|
||||
}
|
||||
|
||||
return structType(loc, ty, /*isResult=*/false);
|
||||
}
|
||||
|
||||
CodeGenSpecifics::Marshalling
|
||||
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
|
||||
return structType(loc, ty, /*isResult=*/true);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
73
flang/test/Fir/struct-passing-aarch64-byval.fir
Normal file
73
flang/test/Fir/struct-passing-aarch64-byval.fir
Normal file
@ -0,0 +1,73 @@
|
||||
// Test AArch64 ABI rewrite of struct passed by value (BIND(C), VALUE derived types).
|
||||
// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>)
|
||||
func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>)
|
||||
// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>)
|
||||
func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>)
|
||||
// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>)
|
||||
func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>)
|
||||
// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>)
|
||||
func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>)
|
||||
|
||||
// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
|
||||
func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
|
||||
// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
|
||||
func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
|
||||
// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
|
||||
func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
|
||||
// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
|
||||
func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
|
||||
// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
|
||||
func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
|
||||
|
||||
// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>)
|
||||
func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>)
|
||||
// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
|
||||
func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
|
||||
// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>)
|
||||
func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>)
|
||||
|
||||
// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>)
|
||||
func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>)
|
||||
// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
|
||||
func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
|
||||
// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>)
|
||||
func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>)
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.array<2xi64>,
|
||||
// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>})
|
||||
func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>,
|
||||
!fir.type<int_max{i:i64,j:i64}>)
|
||||
// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>})
|
||||
func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
|
||||
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
|
||||
|
||||
// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>})
|
||||
func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>)
|
Loading…
x
Reference in New Issue
Block a user