Merge pull request #24242 from ROCm:ci_bazel_build

PiperOrigin-RevId: 684802112
This commit is contained in:
jax authors 2024-10-11 05:28:31 -07:00
commit e58ef1af37
2 changed files with 21 additions and 6 deletions

View File

@ -277,6 +277,7 @@ def jax_multiplatform_test(
"//jax:test_util",
] + deps + if_building_jaxlib([
"//jaxlib/cuda:gpu_only_test_deps",
"//jaxlib/rocm:gpu_only_test_deps",
"//jax_plugins:gpu_plugin_only_test_deps",
]),
data = data,

View File

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include <Python.h>
#include <cstddef>
#include <string>
#include <string_view>
#include <utility>
#include "nanobind/nanobind.h"
@ -34,7 +36,8 @@ namespace nb = nanobind;
namespace xla {
namespace {
absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
const char* fn_name_c_str, size_t fn_name_size,
nb::object fn, int api_version,
XLA_FFI_Handler_Traits traits) {
if (c_api->extension_start == nullptr) {
@ -59,8 +62,8 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name,
PJRT_Gpu_Register_Custom_Call_Args args;
args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
args.function_name = fn_name.c_str();
args.function_name_size = nb::len(fn_name);
args.function_name = fn_name_c_str;
args.function_name_size = fn_name_size;
#if PJRT_API_GPU_EXTENSION_VERSION >= 1
args.api_version = api_version;
@ -179,12 +182,23 @@ NB_MODULE(rocm_plugin_extension, m) {
tsl::ImportNumpy();
m.def(
"register_custom_call_target",
[](nb::capsule c_api, nb::str fn_name, nb::object fn,
[](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn,
nb::str xla_platform_name, int api_version,
XLA_FFI_Handler_Traits traits) {
const char* fn_name_c_str;
size_t fn_name_size;
nb::str fn_name_bn_str;
if (nb::try_cast<nb::str>(fn_name_py, fn_name_bn_str)) {
fn_name_c_str = fn_name_bn_str.c_str();
fn_name_size = nb::len(fn_name_bn_str);
} else{
nb::bytes bytes = nb::cast<nb::bytes>(fn_name_py);
fn_name_c_str = bytes.c_str();
fn_name_size = bytes.size();
}
xla::ThrowIfError(RegisterCustomCallTarget(
static_cast<const PJRT_Api*>(c_api.data()), fn_name, std::move(fn),
api_version, traits));
static_cast<const PJRT_Api*>(c_api.data()), fn_name_c_str,
fn_name_size, std::move(fn), api_version, traits));
},
nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
nb::arg("xla_platform_name"), nb::arg("api_version") = 0,