diff --git a/BUILD.bazel b/BUILD.bazel
index 441f689e3..33cbefd29 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -31,9 +31,6 @@ transitive_py_deps(
         "//jax:experimental",
         "//jax:experimental_colocated_python",
         "//jax:experimental_sparse",
-        "//jax:internal_export_back_compat_test_util",
-        "//jax:internal_test_harnesses",
-        "//jax:internal_test_util",
         "//jax:lax_reference",
         "//jax:pallas_experimental_gpu_ops",
         "//jax:pallas_gpu_ops",
diff --git a/setup.py b/setup.py
index e00a55e1e..80f45285b 100644
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,7 @@ setup(
     long_description_content_type='text/markdown',
     author='JAX team',
     author_email='jax-dev@google.com',
-    packages=find_packages(exclude=["examples", "jax/src/internal_test_util"]),
+    packages=find_packages(exclude=["*examples*", "*internal_test_util*"]),
     package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
     python_requires='>=3.10',
     install_requires=[