Fixed the default api_version= in register_custom_call_target()

PiperOrigin-RevId: 597834961
This commit is contained in:
Sergei Lebedev 2024-01-12 07:26:17 -08:00 committed by jax authors
parent 935db25a2a
commit 87301aa737

View File

@ -87,7 +87,7 @@ NB_MODULE(cuda_plugin_extension, m) {
fn_name, std::move(fn), api_version));
},
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
nb::arg("xla_platform_name"), nb::arg("api_version") = 1);
nb::arg("xla_platform_name"), nb::arg("api_version") = 0);
m.def("registrations", &Registrations);
}
} // namespace xla