mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fixed the default api_version= in register_custom_call_target()
PiperOrigin-RevId: 597834961
This commit is contained in:
parent
935db25a2a
commit
87301aa737
@ -87,7 +87,7 @@ NB_MODULE(cuda_plugin_extension, m) {
|
||||
fn_name, std::move(fn), api_version));
|
||||
},
|
||||
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
|
||||
nb::arg("xla_platform_name"), nb::arg("api_version") = 1);
|
||||
nb::arg("xla_platform_name"), nb::arg("api_version") = 0);
|
||||
m.def("registrations", &Registrations);
|
||||
}
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user