mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24242 from ROCm:ci_bazel_build
PiperOrigin-RevId: 684802112
This commit is contained in:
commit
e58ef1af37
@ -277,6 +277,7 @@ def jax_multiplatform_test(
|
||||
"//jax:test_util",
|
||||
] + deps + if_building_jaxlib([
|
||||
"//jaxlib/cuda:gpu_only_test_deps",
|
||||
"//jaxlib/rocm:gpu_only_test_deps",
|
||||
"//jax_plugins:gpu_plugin_only_test_deps",
|
||||
]),
|
||||
data = data,
|
||||
|
@ -14,7 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <Python.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <utility>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
@ -34,7 +36,8 @@ namespace nb = nanobind;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
|
||||
absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
|
||||
const char* fn_name_c_str, size_t fn_name_size,
|
||||
nb::object fn, int api_version,
|
||||
XLA_FFI_Handler_Traits traits) {
|
||||
if (c_api->extension_start == nullptr) {
|
||||
@ -59,8 +62,8 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
|
||||
|
||||
PJRT_Gpu_Register_Custom_Call_Args args;
|
||||
args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
|
||||
args.function_name = fn_name.c_str();
|
||||
args.function_name_size = nb::len(fn_name);
|
||||
args.function_name = fn_name_c_str;
|
||||
args.function_name_size = fn_name_size;
|
||||
|
||||
#if PJRT_API_GPU_EXTENSION_VERSION >= 1
|
||||
args.api_version = api_version;
|
||||
@ -179,12 +182,23 @@ NB_MODULE(rocm_plugin_extension, m) {
|
||||
tsl::ImportNumpy();
|
||||
m.def(
|
||||
"register_custom_call_target",
|
||||
[](nb::capsule c_api, nb::str fn_name, nb::object fn,
|
||||
[](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn,
|
||||
nb::str xla_platform_name, int api_version,
|
||||
XLA_FFI_Handler_Traits traits) {
|
||||
const char* fn_name_c_str;
|
||||
size_t fn_name_size;
|
||||
nb::str fn_name_bn_str;
|
||||
if (nb::try_cast<nb::str>(fn_name_py, fn_name_bn_str)) {
|
||||
fn_name_c_str = fn_name_bn_str.c_str();
|
||||
fn_name_size = nb::len(fn_name_bn_str);
|
||||
} else{
|
||||
nb::bytes bytes = nb::cast<nb::bytes>(fn_name_py);
|
||||
fn_name_c_str = bytes.c_str();
|
||||
fn_name_size = bytes.size();
|
||||
}
|
||||
xla::ThrowIfError(RegisterCustomCallTarget(
|
||||
static_cast<const PJRT_Api*>(c_api.data()), fn_name, std::move(fn),
|
||||
api_version, traits));
|
||||
static_cast<const PJRT_Api*>(c_api.data()), fn_name_c_str,
|
||||
fn_name_size, std::move(fn), api_version, traits));
|
||||
},
|
||||
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
|
||||
nb::arg("xla_platform_name"), nb::arg("api_version") = 0,
|
||||
|
Loading…
x
Reference in New Issue
Block a user