mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Fixed the default api_version= in register_custom_call_target()
PiperOrigin-RevId: 597834961
This commit is contained in:
parent
935db25a2a
commit
87301aa737
@ -87,7 +87,7 @@ NB_MODULE(cuda_plugin_extension, m) {
|
|||||||
fn_name, std::move(fn), api_version));
|
fn_name, std::move(fn), api_version));
|
||||||
},
|
},
|
||||||
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
|
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
|
||||||
nb::arg("xla_platform_name"), nb::arg("api_version") = 1);
|
nb::arg("xla_platform_name"), nb::arg("api_version") = 0);
|
||||||
m.def("registrations", &Registrations);
|
m.def("registrations", &Registrations);
|
||||||
}
|
}
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user