mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Added an api_version field to PJRT_Gpu_Register_Custom_Call*
This allows using the correct registration API for both legacy (untyped) and new (typed) XLA FFI custom calls. PiperOrigin-RevId: 597818106
This commit is contained in:
parent
f625fb69da
commit
935db25a2a
@ -33,7 +33,7 @@ namespace nb = nanobind;
|
||||
namespace xla {
|
||||
namespace {
|
||||
Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
|
||||
nb::capsule fn) {
|
||||
nb::capsule fn, int api_version) {
|
||||
static const char* const kName = "xla._CUSTOM_CALL_TARGET";
|
||||
if (std::string_view(fn.name()) != kName) {
|
||||
return InvalidArgument(
|
||||
@ -59,6 +59,9 @@ Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
|
||||
args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
|
||||
args.function_name = fn_name.c_str();
|
||||
args.function_name_size = nb::len(fn_name);
|
||||
#if PJRT_API_GPU_EXTENSION_VERSION >= 1
|
||||
args.api_version = api_version;
|
||||
#endif
|
||||
args.custom_call_function = static_cast<void*>(fn.data());
|
||||
RETURN_STATUS_IF_PJRT_ERROR(
|
||||
reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(&args),
|
||||
@ -75,12 +78,16 @@ nb::dict Registrations() {
|
||||
} // namespace
|
||||
|
||||
NB_MODULE(cuda_plugin_extension, m) {
|
||||
m.def("register_custom_call_target", [](nb::capsule c_api, nb::str fn_name,
|
||||
nb::capsule fn,
|
||||
nb::str xla_platform_name) {
|
||||
xla::ThrowIfError(RegisterCustomCallTarget(
|
||||
static_cast<const PJRT_Api*>(c_api.data()), fn_name, std::move(fn)));
|
||||
});
|
||||
m.def(
|
||||
"register_custom_call_target",
|
||||
[](nb::capsule c_api, nb::str fn_name, nb::capsule fn,
|
||||
nb::str xla_platform_name, int api_version) {
|
||||
xla::ThrowIfError(
|
||||
RegisterCustomCallTarget(static_cast<const PJRT_Api*>(c_api.data()),
|
||||
fn_name, std::move(fn), api_version));
|
||||
},
|
||||
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
|
||||
nb::arg("xla_platform_name"), nb::arg("api_version") = 1);
|
||||
m.def("registrations", &Registrations);
|
||||
}
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user