[Offload][OMPT] Add callbacks for (dis)associate_ptr (#99046)

This adds the OMPT callbacks for the API functions disassociate_ptr and
associate_ptr.
This commit is contained in:
Jan Patrick Lehr 2024-07-17 10:15:19 +02:00 committed by GitHub
parent 0905732f75
commit caaf8099ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 136 additions and 0 deletions

View File

@ -109,6 +109,25 @@ public:
/// Top-level function for invoking callback after target update construct
void endTargetUpdate(int64_t DeviceId, void *Code);
/// Top-level function for invoking callback before target associate API
void beginTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size, void *Code);
/// Top-level function for invoking callback after target associate API
void endTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size, void *Code);
/// Top-level function for invoking callback before target disassociate API
void beginTargetDisassociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size,
void *Code);
/// Top-level function for invoking callback after target disassociate API
void endTargetDisassociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size, void *Code);
// Target kernel callbacks
/// Top-level function for invoking callback before target construct
void beginTarget(int64_t DeviceId, void *Code);
@ -137,6 +156,16 @@ public:
return std::make_pair(std::mem_fn(&Interface::beginTargetDataRetrieve),
std::mem_fn(&Interface::endTargetDataRetrieve));
if constexpr (OpType == ompt_target_data_associate)
return std::make_pair(
std::mem_fn(&Interface::beginTargetAssociatePointer),
std::mem_fn(&Interface::endTargetAssociatePointer));
if constexpr (OpType == ompt_target_data_disassociate)
return std::make_pair(
std::mem_fn(&Interface::beginTargetDisassociatePointer),
std::mem_fn(&Interface::endTargetDisassociatePointer));
llvm_unreachable("Unhandled target data operation type!");
}

View File

@ -597,6 +597,12 @@ EXTERN int omp_target_associate_ptr(const void *HostPtr, const void *DevicePtr,
FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
void *DeviceAddr = (void *)((uint64_t)DevicePtr + (uint64_t)DeviceOffset);
OMPT_IF_BUILT(InterfaceRAII(
RegionInterface.getCallbacks<ompt_target_data_associate>(), DeviceNum,
const_cast<void *>(HostPtr), const_cast<void *>(DevicePtr), Size,
__builtin_return_address(0)));
int Rc = DeviceOrErr->getMappingInfo().associatePtr(
const_cast<void *>(HostPtr), const_cast<void *>(DeviceAddr), Size);
DP("omp_target_associate_ptr returns %d\n", Rc);
@ -625,6 +631,11 @@ EXTERN int omp_target_disassociate_ptr(const void *HostPtr, int DeviceNum) {
if (!DeviceOrErr)
FATAL_MESSAGE(DeviceNum, "%s", toString(DeviceOrErr.takeError()).c_str());
OMPT_IF_BUILT(InterfaceRAII(
RegionInterface.getCallbacks<ompt_target_data_disassociate>(), DeviceNum,
const_cast<void *>(HostPtr),
/*DevicePtr=*/nullptr, /*Size=*/0, __builtin_return_address(0)));
int Rc = DeviceOrErr->getMappingInfo().disassociatePtr(
const_cast<void *>(HostPtr));
DP("omp_target_disassociate_ptr returns %d\n", Rc);

View File

@ -332,6 +332,63 @@ void Interface::endTargetUpdate(int64_t DeviceId, void *Code) {
endTargetRegion();
}
void Interface::beginTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size,
void *Code) {
beginTargetDataOperation();
if (ompt_callback_target_data_op_emi_fn) {
ompt_callback_target_data_op_emi_fn(
ompt_scope_begin, TargetTaskData, &TargetData, &HostOpId,
ompt_target_data_associate, HstPtrBegin, omp_get_initial_device(),
TgtPtrBegin, DeviceId, Size, Code);
} else if (ompt_callback_target_data_op_fn) {
HostOpId = createOpId();
ompt_callback_target_data_op_fn(
TargetData.value, HostOpId, ompt_target_data_associate, HstPtrBegin,
omp_get_initial_device(), TgtPtrBegin, DeviceId, Size, Code);
}
}
void Interface::endTargetAssociatePointer(int64_t DeviceId, void *HstPtrBegin,
void *TgtPtrBegin, size_t Size,
void *Code) {
if (ompt_callback_target_data_op_emi_fn) {
ompt_callback_target_data_op_emi_fn(
ompt_scope_end, TargetTaskData, &TargetData, &HostOpId,
ompt_target_data_associate, HstPtrBegin, omp_get_initial_device(),
TgtPtrBegin, DeviceId, Size, Code);
}
}
void Interface::beginTargetDisassociatePointer(int64_t DeviceId,
void *HstPtrBegin,
void *TgtPtrBegin, size_t Size,
void *Code) {
beginTargetDataOperation();
if (ompt_callback_target_data_op_emi_fn) {
ompt_callback_target_data_op_emi_fn(
ompt_scope_begin, TargetTaskData, &TargetData, &HostOpId,
ompt_target_data_disassociate, HstPtrBegin, omp_get_initial_device(),
TgtPtrBegin, DeviceId, Size, Code);
} else if (ompt_callback_target_data_op_fn) {
HostOpId = createOpId();
ompt_callback_target_data_op_fn(
TargetData.value, HostOpId, ompt_target_data_disassociate, HstPtrBegin,
omp_get_initial_device(), TgtPtrBegin, DeviceId, Size, Code);
}
}
void Interface::endTargetDisassociatePointer(int64_t DeviceId,
void *HstPtrBegin,
void *TgtPtrBegin, size_t Size,
void *Code) {
if (ompt_callback_target_data_op_emi_fn) {
ompt_callback_target_data_op_emi_fn(
ompt_scope_end, TargetTaskData, &TargetData, &HostOpId,
ompt_target_data_disassociate, HstPtrBegin, omp_get_initial_device(),
TgtPtrBegin, DeviceId, Size, Code);
}
}
void Interface::beginTarget(int64_t DeviceId, void *Code) {
beginTargetRegion();
if (ompt_callback_target_emi_fn) {

View File

@ -0,0 +1,39 @@
// RUN: %libomptarget-compile-run-and-check-generic
// REQUIRES: ompt
// REQUIRES: gpu
#include "omp.h"
#include <stdlib.h>
#include <string.h>
#include "callbacks.h"
#include "register_non_emi.h"
#define N 1024
int main(int argc, char **argv) {
int *h_a;
int *d_a;
h_a = (int *)malloc(N * sizeof(int));
memset(h_a, 0, N);
d_a = (int *)omp_target_alloc(N * sizeof(int), omp_get_default_device());
omp_target_associate_ptr(h_a, d_a, N * sizeof(int), 0,
omp_get_default_device());
omp_target_disassociate_ptr(h_a, omp_get_default_device());
omp_target_free(d_a, omp_get_default_device());
free(h_a);
return 0;
}
// clang-format off
/// CHECK: Callback Init:
/// CHECK: Callback DataOp: target_id=[[TARGET_ID:[0-9]+]] host_op_id=[[HOST_OP_ID:[0-9]+]] optype=1
/// CHECK: Callback DataOp: target_id=[[TARGET_ID:[0-9]+]] host_op_id=[[HOST_OP_ID:[0-9]+]] optype=5
/// CHECK: Callback DataOp: target_id=[[TARGET_ID:[0-9]+]] host_op_id=[[HOST_OP_ID:[0-9]+]] optype=6
/// CHECK: Callback DataOp: target_id=[[TARGET_ID:[0-9]+]] host_op_id=[[HOST_OP_ID:[0-9]+]] optype=4
/// CHECK: Callback Fini: