diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD.bazel index 342e65ea2..58dfd076d 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD.bazel @@ -60,8 +60,8 @@ cc_library( rocm_library( name = "hip_make_batch_pointers", - srcs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.cu.cc"], - hdrs = ["//third_party/py/jax/jaxlib/gpu:make_batch_pointers.h"], + srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"], + hdrs = ["//jaxlib/gpu:make_batch_pointers.h"], deps = [ ":hip_vendor", "@local_config_rocm//rocm:rocm_headers",