mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[mosaic_gpu] Check the return code of gpuEventCreate
and gpuEventDestroy
PiperOrigin-RevId: 693260326
This commit is contained in:
parent
63e59c5fd7
commit
34b4787e2e
@ -184,6 +184,7 @@ pybind_extension(
|
||||
deps = [
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"//jaxlib/cuda:cuda_vendor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@nanobind",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
|
@ -14,8 +14,11 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "jaxlib/kernel_nanobind_helpers.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
@ -23,33 +26,52 @@ limitations under the License.
|
||||
namespace jax::cuda {
|
||||
namespace {
|
||||
|
||||
namespace nb = nanobind;
|
||||
static std::string ToString(CUresult result) {
|
||||
const char* error_name;
|
||||
if (cuGetErrorName(result, &error_name)) {
|
||||
return absl::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
|
||||
}
|
||||
const char* error_string;
|
||||
if (cuGetErrorString(result, &error_string)) {
|
||||
return error_name;
|
||||
}
|
||||
return absl::StrCat(error_name, ": ", error_string);
|
||||
}
|
||||
|
||||
void EventRecordCall(void* stream, void** buffers, char* opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus* status) {
|
||||
auto* event = reinterpret_cast<gpuEvent_t**>(opaque);
|
||||
if (gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream)) !=
|
||||
gpuSuccess) {
|
||||
const char message[] = "Failed to record event";
|
||||
XlaCustomCallStatusSetFailure(status, message, sizeof(message));
|
||||
if (auto res = gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream));
|
||||
res) {
|
||||
auto message = absl::StrCat("Failed to record event: ", ToString(res));
|
||||
XlaCustomCallStatusSetFailure(status, message.c_str(), message.size());
|
||||
}
|
||||
}
|
||||
|
||||
NB_MODULE(_mosaic_gpu_ext, m) {
|
||||
m.def("_gpu_event_create", []() {
|
||||
gpuEvent_t* event = new gpuEvent_t();
|
||||
gpuEventCreate(event, GPU_EVENT_DEFAULT);
|
||||
if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); res) {
|
||||
throw std::runtime_error(
|
||||
absl::StrCat("Failed to create event: ", ToString(res)));
|
||||
}
|
||||
return reinterpret_cast<uintptr_t>(event);
|
||||
});
|
||||
m.def("_gpu_event_destroy", [](uintptr_t event) {
|
||||
gpuEventDestroy(*reinterpret_cast<gpuEvent_t*>(event));
|
||||
if (auto res = gpuEventDestroy(*reinterpret_cast<gpuEvent_t*>(event));
|
||||
res) {
|
||||
throw std::runtime_error(
|
||||
absl::StrCat("Failed to destroy event: ", ToString(res)));
|
||||
}
|
||||
});
|
||||
m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) {
|
||||
float elapsed_ms = -1;
|
||||
if (gpuEventElapsedTime(
|
||||
if (auto res = gpuEventElapsedTime(
|
||||
&elapsed_ms, *reinterpret_cast<gpuEvent_t*>(start_event),
|
||||
*reinterpret_cast<gpuEvent_t*>(end_event)) != gpuSuccess) {
|
||||
throw std::runtime_error("Failed to get elapsed time between events");
|
||||
*reinterpret_cast<gpuEvent_t*>(end_event));
|
||||
res) {
|
||||
throw std::runtime_error(absl::StrCat(
|
||||
"Failed to get elapsed time between events: ", ToString(res)));
|
||||
}
|
||||
return elapsed_ms;
|
||||
});
|
||||
|
Loading…
x
Reference in New Issue
Block a user