mirror of
https://github.com/llvm/llvm-project.git
synced 2025-04-28 16:46:09 +00:00

Wrapper-function handlers returning void previously caused compile errors. Fix this by substituting SPSEmpty placeholder values for their results.
368 lines
13 KiB
C++
368 lines
13 KiB
C++
//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file is a part of the ORC runtime support library.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
|
|
#define ORC_RT_WRAPPER_FUNCTION_UTILS_H
|
|
|
|
#include "c_api.h"
|
|
#include "common.h"
|
|
#include "error.h"
|
|
#include "simple_packed_serialization.h"
|
|
#include <type_traits>
|
|
|
|
namespace __orc_rt {
|
|
|
|
/// C++ wrapper function result: Same as CWrapperFunctionResult but
|
|
/// auto-releases memory.
|
|
class WrapperFunctionResult {
|
|
public:
|
|
/// Create a default WrapperFunctionResult.
|
|
WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); }
|
|
|
|
/// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
|
|
/// instance takes ownership of the result object and will automatically
|
|
/// call dispose on the result upon destruction.
|
|
WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {}
|
|
|
|
WrapperFunctionResult(const WrapperFunctionResult &) = delete;
|
|
WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
|
|
|
|
WrapperFunctionResult(WrapperFunctionResult &&Other) {
|
|
__orc_rt_CWrapperFunctionResultInit(&R);
|
|
std::swap(R, Other.R);
|
|
}
|
|
|
|
WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
|
|
__orc_rt_CWrapperFunctionResult Tmp;
|
|
__orc_rt_CWrapperFunctionResultInit(&Tmp);
|
|
std::swap(Tmp, Other.R);
|
|
std::swap(R, Tmp);
|
|
return *this;
|
|
}
|
|
|
|
~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); }
|
|
|
|
/// Relinquish ownership of and return the
|
|
/// __orc_rt_CWrapperFunctionResult.
|
|
__orc_rt_CWrapperFunctionResult release() {
|
|
__orc_rt_CWrapperFunctionResult Tmp;
|
|
__orc_rt_CWrapperFunctionResultInit(&Tmp);
|
|
std::swap(R, Tmp);
|
|
return Tmp;
|
|
}
|
|
|
|
/// Get a pointer to the data contained in this instance.
|
|
const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); }
|
|
|
|
/// Returns the size of the data contained in this instance.
|
|
size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); }
|
|
|
|
/// Returns true if this value is equivalent to a default-constructed
|
|
/// WrapperFunctionResult.
|
|
bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); }
|
|
|
|
/// Create a WrapperFunctionResult with the given size and return a pointer
|
|
/// to the underlying memory.
|
|
static char *allocate(WrapperFunctionResult &R, size_t Size) {
|
|
__orc_rt_DisposeCWrapperFunctionResult(&R.R);
|
|
__orc_rt_CWrapperFunctionResultInit(&R.R);
|
|
return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size);
|
|
}
|
|
|
|
/// Copy from the given char range.
|
|
static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
|
|
return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
|
|
}
|
|
|
|
/// Copy from the given null-terminated string (includes the null-terminator).
|
|
static WrapperFunctionResult copyFrom(const char *Source) {
|
|
return __orc_rt_CreateCWrapperFunctionResultFromString(Source);
|
|
}
|
|
|
|
/// Copy from the given std::string (includes the null terminator).
|
|
static WrapperFunctionResult copyFrom(const std::string &Source) {
|
|
return copyFrom(Source.c_str());
|
|
}
|
|
|
|
/// Create an out-of-band error by copying the given string.
|
|
static WrapperFunctionResult createOutOfBandError(const char *Msg) {
|
|
return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
|
|
}
|
|
|
|
/// Create an out-of-band error by copying the given string.
|
|
static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
|
|
return createOutOfBandError(Msg.c_str());
|
|
}
|
|
|
|
/// If this value is an out-of-band error then this returns the error message,
|
|
/// otherwise returns nullptr.
|
|
const char *getOutOfBandError() const {
|
|
return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
|
|
}
|
|
|
|
private:
|
|
__orc_rt_CWrapperFunctionResult R;
|
|
};
|
|
|
|
namespace detail {
|
|
|
|
template <typename SPSArgListT, typename... ArgTs>
|
|
Expected<WrapperFunctionResult>
|
|
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
|
|
WrapperFunctionResult Result;
|
|
char *DataPtr =
|
|
WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
|
|
SPSOutputBuffer OB(DataPtr, Result.size());
|
|
if (!SPSArgListT::serialize(OB, Args...))
|
|
return make_error<StringError>(
|
|
"Error serializing arguments to blob in call");
|
|
return Result;
|
|
}
|
|
|
|
/// Invokes a wrapper-function handler with the tuple elements selected by an
/// index pack.
///
/// The handler's return value is passed through unchanged. decltype(auto)
/// keeps references as references and prvalues as values.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgTuple,
                             std::index_sequence<Idx...>) {
    // Forward the handler so that rvalue-qualified call operators are
    // honoured. The tuple elements are passed as lvalues.
    return std::forward<HandlerT>(Handler)(std::get<Idx>(ArgTuple)...);
  }
};
|
|
|
|
template <> class WrapperFunctionHandlerCaller<void> {
|
|
public:
|
|
template <typename HandlerT, typename ArgTupleT, std::size_t... I>
|
|
static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
|
|
std::index_sequence<I...>) {
|
|
std::forward<HandlerT>(H)(std::get<I>(Args)...);
|
|
return SPSEmpty();
|
|
}
|
|
};
|
|
|
|
/// Primary template: used for function objects such as lambdas.
///
/// The handler's signature is recovered from the type of a pointer to its
/// call operator (after stripping any reference from the handler type). That
/// member-function-pointer type is then matched by one of the
/// specializations in this file, which forward to the plain function-type
/// specialization.
template <typename WrapperFunctionImplT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference<WrapperFunctionImplT>::type::
                       operator()),
          ResultSerializer, SPSTagTs...> {};
|
|
|
|
template <typename RetT, typename... ArgTs,
|
|
template <typename> class ResultSerializer, typename... SPSTagTs>
|
|
class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
|
|
SPSTagTs...> {
|
|
public:
|
|
using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
|
|
using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
|
|
|
|
template <typename HandlerT>
|
|
static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
|
|
size_t ArgSize) {
|
|
ArgTuple Args;
|
|
if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
|
|
return WrapperFunctionResult::createOutOfBandError(
|
|
"Could not deserialize arguments for wrapper function call");
|
|
|
|
auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
|
|
std::forward<HandlerT>(H), Args, ArgIndices{});
|
|
|
|
if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
|
|
std::move(HandlerResult)))
|
|
return std::move(*Result);
|
|
else
|
|
return WrapperFunctionResult::createOutOfBandError(
|
|
toString(Result.takeError()));
|
|
}
|
|
|
|
private:
|
|
template <std::size_t... I>
|
|
static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
|
|
std::index_sequence<I...>) {
|
|
SPSInputBuffer IB(ArgData, ArgSize);
|
|
return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
|
|
}
|
|
|
|
};
|
|
|
|
// Map function references to function types.
|
|
template <typename RetT, typename... ArgTs,
|
|
template <typename> class ResultSerializer, typename... SPSTagTs>
|
|
class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
|
|
SPSTagTs...>
|
|
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
|
|
SPSTagTs...> {};
|
|
|
|
// Map non-const member function types to function types.
|
|
template <typename ClassT, typename RetT, typename... ArgTs,
|
|
template <typename> class ResultSerializer, typename... SPSTagTs>
|
|
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
|
|
SPSTagTs...>
|
|
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
|
|
SPSTagTs...> {};
|
|
|
|
// Map const member function types to function types.
|
|
template <typename ClassT, typename RetT, typename... ArgTs,
|
|
template <typename> class ResultSerializer, typename... SPSTagTs>
|
|
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
|
|
ResultSerializer, SPSTagTs...>
|
|
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
|
|
SPSTagTs...> {};
|
|
|
|
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
|
|
public:
|
|
static Expected<WrapperFunctionResult> serialize(RetT Result) {
|
|
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
|
|
Result);
|
|
}
|
|
};
|
|
|
|
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
|
|
public:
|
|
static Expected<WrapperFunctionResult> serialize(Error Err) {
|
|
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
|
|
toSPSSerializable(std::move(Err)));
|
|
}
|
|
};
|
|
|
|
template <typename SPSRetTagT, typename T>
|
|
class ResultSerializer<SPSRetTagT, Expected<T>> {
|
|
public:
|
|
static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
|
|
return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
|
|
toSPSSerializable(std::move(E)));
|
|
}
|
|
};
|
|
|
|
template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
|
|
public:
|
|
static void makeSafe(RetT &Result) {}
|
|
|
|
static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
|
|
SPSInputBuffer IB(ArgData, ArgSize);
|
|
if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
|
|
return make_error<StringError>(
|
|
"Error deserializing return value from blob in call");
|
|
return Error::success();
|
|
}
|
|
};
|
|
|
|
template <> class ResultDeserializer<SPSError, Error> {
|
|
public:
|
|
static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
|
|
|
|
static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
|
|
SPSInputBuffer IB(ArgData, ArgSize);
|
|
SPSSerializableError BSE;
|
|
if (!SPSArgList<SPSError>::deserialize(IB, BSE))
|
|
return make_error<StringError>(
|
|
"Error deserializing return value from blob in call");
|
|
Err = fromSPSSerializable(std::move(BSE));
|
|
return Error::success();
|
|
}
|
|
};
|
|
|
|
template <typename SPSTagT, typename T>
|
|
class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
|
|
public:
|
|
static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
|
|
|
|
static Error deserialize(Expected<T> &E, const char *ArgData,
|
|
size_t ArgSize) {
|
|
SPSInputBuffer IB(ArgData, ArgSize);
|
|
SPSSerializableExpected<T> BSE;
|
|
if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
|
|
return make_error<StringError>(
|
|
"Error deserializing return value from blob in call");
|
|
E = fromSPSSerializable(std::move(BSE));
|
|
return Error::success();
|
|
}
|
|
};
|
|
|
|
} // end namespace detail
|
|
|
|
template <typename SPSSignature> class WrapperFunction;
|
|
|
|
template <typename SPSRetTagT, typename... SPSTagTs>
|
|
class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
|
|
private:
|
|
template <typename RetT>
|
|
using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
|
|
|
|
public:
|
|
template <typename RetT, typename... ArgTs>
|
|
static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
|
|
|
|
// RetT might be an Error or Expected value. Set the checked flag now:
|
|
// we don't want the user to have to check the unused result if this
|
|
// operation fails.
|
|
detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
|
|
|
|
if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
|
|
return make_error<StringError>("__orc_jtjit_dispatch_ctx not set");
|
|
if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
|
|
return make_error<StringError>("__orc_jtjit_dispatch not set");
|
|
|
|
auto ArgBuffer =
|
|
detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
|
|
Args...);
|
|
if (!ArgBuffer)
|
|
return ArgBuffer.takeError();
|
|
|
|
WrapperFunctionResult ResultBuffer =
|
|
__orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag,
|
|
ArgBuffer->data(), ArgBuffer->size());
|
|
if (auto ErrMsg = ResultBuffer.getOutOfBandError())
|
|
return make_error<StringError>(ErrMsg);
|
|
|
|
return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
|
|
Result, ResultBuffer.data(), ResultBuffer.size());
|
|
}
|
|
|
|
template <typename HandlerT>
|
|
static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
|
|
HandlerT &&Handler) {
|
|
using WFHH =
|
|
detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
|
|
SPSTagTs...>;
|
|
return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
|
|
}
|
|
|
|
private:
|
|
template <typename T> static const T &makeSerializable(const T &Value) {
|
|
return Value;
|
|
}
|
|
|
|
static detail::SPSSerializableError makeSerializable(Error Err) {
|
|
return detail::toSPSSerializable(std::move(Err));
|
|
}
|
|
|
|
template <typename T>
|
|
static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
|
|
return detail::toSPSSerializable(std::move(E));
|
|
}
|
|
};
|
|
|
|
template <typename... SPSTagTs>
|
|
class WrapperFunction<void(SPSTagTs...)>
|
|
: private WrapperFunction<SPSEmpty(SPSTagTs...)> {
|
|
public:
|
|
template <typename... ArgTs>
|
|
static Error call(const void *FnTag, const ArgTs &...Args) {
|
|
SPSEmpty BE;
|
|
return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
|
|
}
|
|
|
|
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
|
|
};
|
|
|
|
} // end namespace __orc_rt
|
|
|
|
#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
|