diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index 440e4d1a7..b7abb136a 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -23,12 +23,9 @@ cc_binary( deps = [ "@xla//xla:literal", "@xla//xla:literal_util", - "@xla//xla:shape_util", - "@xla//xla:status", "@xla//xla:statusor", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt/cpu:cpu_client", - "@xla//xla/service:hlo_proto_cc", "@xla//xla/tools:hlo_module_loader", "@tsl//tsl/platform:logging", "@tsl//tsl/platform:platform_port", diff --git a/examples/jax_cpp/main.cc b/examples/jax_cpp/main.cc index 26ebb1ecd..42ceb6f51 100644 --- a/examples/jax_cpp/main.cc +++ b/examples/jax_cpp/main.cc @@ -44,7 +44,6 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/status.h" #include "xla/statusor.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 2c59acdcf..a9d06df6f 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -187,6 +187,7 @@ pybind_extension( srcs = ["cuda_plugin_extension.cc"], module_name = "cuda_plugin_extension", deps = [ + "@com_google_absl//absl/status", "@nanobind", "//jaxlib:kernel_nanobind_helpers", "@xla//third_party/python_runtime:headers", diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda_plugin_extension.cc index bffbcbd98..28007de73 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda_plugin_extension.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "nanobind/nanobind.h" +#include "absl/status/status.h" #include "third_party/gpus/cuda/include/cuda.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -26,7 +27,6 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/status_casters.h" #include "xla/python/py_client_gpu.h" -#include "xla/status.h" #include "xla/tsl/python/lib/core/numpy.h" #include "xla/util.h"