Added an api_version field to PJRT_Gpu_Register_Custom_Call_Args

This allows using the correct registration API for both legacy (untyped) and
new (typed) XLA FFI custom calls.

PiperOrigin-RevId: 597818106
This commit is contained in:
Sergei Lebedev 2024-01-12 05:46:48 -08:00 committed by jax authors
parent f625fb69da
commit 935db25a2a

View File

@ -33,7 +33,7 @@ namespace nb = nanobind;
namespace xla {
namespace {
Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
nb::capsule fn) {
nb::capsule fn, int api_version) {
static const char* const kName = "xla._CUSTOM_CALL_TARGET";
if (std::string_view(fn.name()) != kName) {
return InvalidArgument(
@ -59,6 +59,9 @@ Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
args.function_name = fn_name.c_str();
args.function_name_size = nb::len(fn_name);
#if PJRT_API_GPU_EXTENSION_VERSION >= 1
args.api_version = api_version;
#endif
args.custom_call_function = static_cast<void*>(fn.data());
RETURN_STATUS_IF_PJRT_ERROR(
reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(&args),
@ -75,12 +78,16 @@ nb::dict Registrations() {
} // namespace
NB_MODULE(cuda_plugin_extension, m) {
m.def("register_custom_call_target", [](nb::capsule c_api, nb::str fn_name,
nb::capsule fn,
nb::str xla_platform_name) {
xla::ThrowIfError(RegisterCustomCallTarget(
static_cast<const PJRT_Api*>(c_api.data()), fn_name, std::move(fn)));
});
m.def(
"register_custom_call_target",
[](nb::capsule c_api, nb::str fn_name, nb::capsule fn,
nb::str xla_platform_name, int api_version) {
xla::ThrowIfError(
RegisterCustomCallTarget(static_cast<const PJRT_Api*>(c_api.data()),
fn_name, std::move(fn), api_version));
},
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
nb::arg("xla_platform_name"), nb::arg("api_version") = 1);
m.def("registrations", &Registrations);
}
} // namespace xla