[Offload] Make only a single thread handle the RPC server thread (#126067)

Summary:
This patch just changes the interface to make starting the thread
multiple times permissible since it will only be done the first time.
Note that this does not refcount it or anything, so it's up to the user
to make sure that they don't shut down the thread before everyone is
done using it. That is the case today because the shutDown portion is
run by a single thread in the destructor phase.

Another question is if we should make this thread truly global state,
because currently it will be private to each plugin instance, so if you
have an AMD and NVIDIA image there will be two, similarly if you have
those inside of a shared library.
This commit is contained in:
Joseph Huber 2025-02-06 11:38:14 -06:00 committed by GitHub
parent 11c3f52bbb
commit 5812d0bf8e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 14 deletions

View File

@ -80,7 +80,7 @@ private:
std::thread Worker;
/// A boolean indicating whether or not the worker thread should continue.
std::atomic<bool> Running;
std::atomic<uint32_t> Running;
/// The number of currently executing kernels across all devices that need
/// the server thread to be running.

View File

@ -1058,9 +1058,8 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
return Err;
if (!Server.Thread->Running.load(std::memory_order_acquire))
if (auto Err = Server.startThread())
return Err;
if (auto Err = Server.startThread())
return Err;
RPCServer = &Server;
DP("Running an RPC server on device %d\n", getDeviceId());
@ -1635,12 +1634,11 @@ Error GenericPluginTy::deinit() {
if (GlobalHandler)
delete GlobalHandler;
if (RPCServer && RPCServer->Thread->Running.load(std::memory_order_acquire))
if (RPCServer) {
if (Error Err = RPCServer->shutDown())
return Err;
if (RPCServer)
delete RPCServer;
}
if (RecordReplay)
delete RecordReplay;

View File

@ -99,18 +99,15 @@ static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
}
void RPCServerTy::ServerThread::startThread() {
assert(!Running.load(std::memory_order_relaxed) &&
"Attempting to start thread that is already running");
Running.store(true, std::memory_order_release);
Worker = std::thread([this]() { run(); });
if (!Running.fetch_or(true, std::memory_order_acquire))
Worker = std::thread([this]() { run(); });
}
void RPCServerTy::ServerThread::shutDown() {
assert(Running.load(std::memory_order_relaxed) &&
"Attempting to shut down a thread that is not running");
if (!Running.fetch_and(false, std::memory_order_release))
return;
{
std::lock_guard<decltype(Mutex)> Lock(Mutex);
Running.store(false, std::memory_order_release);
CV.notify_all();
}
if (Worker.joinable())