diff --git a/offload/plugins-nextgen/common/include/RPC.h b/offload/plugins-nextgen/common/include/RPC.h index 42fca4aa4aeb..08556f15a76b 100644 --- a/offload/plugins-nextgen/common/include/RPC.h +++ b/offload/plugins-nextgen/common/include/RPC.h @@ -72,6 +72,9 @@ private: /// Array of associated devices. These must be alive as long as the server is. std::unique_ptr Devices; + /// Mutex that guards accesses to the buffers and device array. + std::mutex BufferMutex{}; + /// A helper class for running the user thread that handles the RPC interface. /// Because we only need to check the RPC server while any kernels are /// working, we track submission / completion events to allow the thread to @@ -90,6 +93,9 @@ private: std::condition_variable CV; std::mutex Mutex; + /// A reference to the main server's mutex. + std::mutex &BufferMutex; + /// A reference to all the RPC interfaces that the server is handling. llvm::ArrayRef Buffers; @@ -98,9 +104,9 @@ private: /// Initialize the worker thread to run in the background. ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[], - size_t Length) - : Running(false), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length), - Devices(Devices, Length) {} + size_t Length, std::mutex &BufferMutex) + : Running(false), NumUsers(0), CV(), Mutex(), BufferMutex(BufferMutex), + Buffers(Buffers, Length), Devices(Devices, Length) {} ~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); } diff --git a/offload/plugins-nextgen/common/src/RPC.cpp b/offload/plugins-nextgen/common/src/RPC.cpp index e6750a540b39..eb305736d626 100644 --- a/offload/plugins-nextgen/common/src/RPC.cpp +++ b/offload/plugins-nextgen/common/src/RPC.cpp @@ -131,6 +131,7 @@ void RPCServerTy::ServerThread::run() { Lock.unlock(); while (NumUsers.load(std::memory_order_relaxed) > 0 && Running.load(std::memory_order_relaxed)) { + std::lock_guard Lock(BufferMutex); for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) { if (!Buffer || !Device) continue; @@ -149,7 +150,7 @@ RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin) Devices(std::make_unique( Plugin.getNumDevices())), Thread(new ServerThread(Buffers.get(), Devices.get(), - Plugin.getNumDevices())) {} + Plugin.getNumDevices(), BufferMutex)) {} llvm::Error RPCServerTy::startThread() { Thread->startThread(); @@ -190,6 +191,7 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client, sizeof(rpc::Client), nullptr)) return Err; + std::lock_guard Lock(BufferMutex); Buffers[Device.getDeviceId()] = RPCBuffer; Devices[Device.getDeviceId()] = &Device; @@ -197,6 +199,7 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device, } Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) { + std::lock_guard Lock(BufferMutex); Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST); Buffers[Device.getDeviceId()] = nullptr; Devices[Device.getDeviceId()] = nullptr;