[mosaic_gpu] Check the return code of gpuEventCreate and gpuEventDestroy

PiperOrigin-RevId: 693260326
This commit is contained in:
Sergei Lebedev 2024-11-05 01:59:21 -08:00 committed by jax authors
parent 63e59c5fd7
commit 34b4787e2e
2 changed files with 33 additions and 10 deletions

View File

@ -184,6 +184,7 @@ pybind_extension(
deps = [
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/service:custom_call_status",
"@xla//xla/tsl/cuda:cudart",

View File

@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include <cstdint>
#include <stdexcept>
#include <string>
#include "nanobind/nanobind.h"
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/service/custom_call_status.h"
@ -23,33 +26,52 @@ limitations under the License.
namespace jax::cuda {
namespace {
namespace nb = nanobind;
static std::string ToString(CUresult result) {
const char* error_name;
if (cuGetErrorName(result, &error_name)) {
return absl::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
}
const char* error_string;
if (cuGetErrorString(result, &error_string)) {
return error_name;
}
return absl::StrCat(error_name, ": ", error_string);
}
void EventRecordCall(void* stream, void** buffers, char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto* event = reinterpret_cast<gpuEvent_t**>(opaque);
if (gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream)) !=
gpuSuccess) {
const char message[] = "Failed to record event";
XlaCustomCallStatusSetFailure(status, message, sizeof(message));
if (auto res = gpuEventRecord(**event, reinterpret_cast<gpuStream_t>(stream));
res) {
auto message = absl::StrCat("Failed to record event: ", ToString(res));
XlaCustomCallStatusSetFailure(status, message.c_str(), message.size());
}
}
NB_MODULE(_mosaic_gpu_ext, m) {
m.def("_gpu_event_create", []() {
gpuEvent_t* event = new gpuEvent_t();
gpuEventCreate(event, GPU_EVENT_DEFAULT);
if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); res) {
throw std::runtime_error(
absl::StrCat("Failed to create event: ", ToString(res)));
}
return reinterpret_cast<uintptr_t>(event);
});
m.def("_gpu_event_destroy", [](uintptr_t event) {
gpuEventDestroy(*reinterpret_cast<gpuEvent_t*>(event));
if (auto res = gpuEventDestroy(*reinterpret_cast<gpuEvent_t*>(event));
res) {
throw std::runtime_error(
absl::StrCat("Failed to destroy event: ", ToString(res)));
}
});
m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) {
float elapsed_ms = -1;
if (gpuEventElapsedTime(
if (auto res = gpuEventElapsedTime(
&elapsed_ms, *reinterpret_cast<gpuEvent_t*>(start_event),
*reinterpret_cast<gpuEvent_t*>(end_event)) != gpuSuccess) {
throw std::runtime_error("Failed to get elapsed time between events");
*reinterpret_cast<gpuEvent_t*>(end_event));
res) {
throw std::runtime_error(absl::StrCat(
"Failed to get elapsed time between events: ", ToString(res)));
}
return elapsed_ms;
});