From eb912ad0d9430f5057c192afa0c3cf60a5b84926 Mon Sep 17 00:00:00 2001
From: jax authors <google-ml-automation@google.com>
Date: Tue, 25 Feb 2025 09:28:35 -0800
Subject: [PATCH] Create `jax` wheel build target.

This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
---
 BUILD                                         |  86 +++++++++++++++
 build/requirements_lock_3_10.txt              |  23 ++++
 build/requirements_lock_3_11.txt              |  23 ++++
 build/requirements_lock_3_12.txt              |  23 ++++
 build/requirements_lock_3_13.txt              |  23 ++++
 build/requirements_lock_3_13_ft.txt           |  23 ++++
 build/test-requirements.txt                   |   1 +
 build_wheel.py                                | 100 ++++++++++++++++++
 jax/BUILD                                     |   9 +-
 jax/_src/pallas/mosaic/BUILD                  |   5 +-
 jax/_src/pallas/triton/BUILD                  |   5 +-
 jax/experimental/array_serialization/BUILD    |  41 +++++++
 jax/extend/mlir/BUILD                         |   9 ++
 jax/extend/mlir/dialects/BUILD                |  18 ++++
 jax/tools/BUILD                               |  19 +++-
 jaxlib/jax.bzl                                |  53 ++++++++--
 tests/BUILD                                   |  12 +++
 .../serialization_test.py                     |   0
 18 files changed, 453 insertions(+), 20 deletions(-)
 create mode 100644 BUILD
 create mode 100644 build_wheel.py
 create mode 100644 jax/experimental/array_serialization/BUILD
 rename {jax/experimental/array_serialization => tests}/serialization_test.py (100%)

diff --git a/BUILD b/BUILD
new file mode 100644
index 000000000..ec6c87166
--- /dev/null
+++ b/BUILD
@@ -0,0 +1,86 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
+load(
+    "//jaxlib:jax.bzl",
+    "jax_wheel",
+)
+
+collect_data_files(
+    name = "transitive_py_data",
+    deps = ["//jax"],
+)
+
+transitive_py_deps(
+    name = "transitive_py_deps",
+    deps = [
+        "//jax",
+        "//jax:compilation_cache",
+        "//jax:experimental",
+        "//jax:experimental_colocated_python",
+        "//jax:experimental_sparse",
+        "//jax:internal_export_back_compat_test_util",
+        "//jax:internal_test_harnesses",
+        "//jax:internal_test_util",
+        "//jax:lax_reference",
+        "//jax:pallas_experimental_gpu_ops",
+        "//jax:pallas_gpu_ops",
+        "//jax:pallas_mosaic_gpu",
+        "//jax:pallas_tpu_ops",
+        "//jax:pallas_triton",
+        "//jax:source_mapper",
+        "//jax:sparse_test_util",
+        "//jax:test_util",
+        "//jax/_src/lib",
+        "//jax/_src/pallas/mosaic_gpu",
+        "//jax/experimental/array_serialization:serialization",
+        "//jax/experimental/jax2tf",
+        "//jax/extend",
+        "//jax/extend:ifrt_programs",
+        "//jax/extend/mlir",
+        "//jax/extend/mlir/dialects",
+        "//jax/tools:colab_tpu",
+        "//jax/tools:jax_to_ir",
+        "//jax/tools:pgo_nsys_converter",
+    ],
+)
+
+py_binary(
+    name = "build_wheel",
+    srcs = ["build_wheel.py"],
+    deps = [
+        "//jaxlib/tools:build_utils",
+        "@pypi_build//:pkg",
+        "@pypi_setuptools//:pkg",
+        "@pypi_wheel//:pkg",
+    ],
+)
+
+jax_wheel(
+    name = "jax_wheel",
+    platform_independent = True,
+    source_files = [
+        ":transitive_py_data",
+        ":transitive_py_deps",
+        "//jax:py.typed",
+        "AUTHORS",
+        "LICENSE",
+        "README.md",
+        "pyproject.toml",
+        "setup.py",
+    ],
+    wheel_binary = ":build_wheel",
+    wheel_name = "jax",
+)
diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt
index ccffa247f..c346daea3 100644
--- a/build/requirements_lock_3_10.txt
+++ b/build/requirements_lock_3_10.txt
@@ -566,6 +566,29 @@ sortedcontainers==2.4.0 \
     --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
     --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
     # via hypothesis
+tensorstore==0.1.72 \
+    --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
+    --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
+    --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
+    --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
+    --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
+    --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
+    --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
+    --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
+    --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
+    --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
+    --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
+    --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
+    --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
+    --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
+    --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
+    --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
+    --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
+    --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
+    --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
+    --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
+    --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
+    # via -r build/test-requirements.txt
 tomli==2.0.1 \
     --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \
     --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f
diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt
index 7f3ee61ff..faa641940 100644
--- a/build/requirements_lock_3_11.txt
+++ b/build/requirements_lock_3_11.txt
@@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
     --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
     --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
     # via hypothesis
+tensorstore==0.1.72 \
+    --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
+    --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
+    --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
+    --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
+    --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
+    --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
+    --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
+    --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
+    --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
+    --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
+    --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
+    --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
+    --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
+    --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
+    --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
+    --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
+    --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
+    --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
+    --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
+    --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
+    --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
+    # via -r build/test-requirements.txt
 typing-extensions==4.12.0rc1 \
     --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
     --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe
diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt
index bf22c3623..d1f8d3ad5 100644
--- a/build/requirements_lock_3_12.txt
+++ b/build/requirements_lock_3_12.txt
@@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
     --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
     --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
     # via hypothesis
+tensorstore==0.1.72 \
+    --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
+    --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
+    --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
+    --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
+    --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
+    --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
+    --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
+    --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
+    --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
+    --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
+    --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
+    --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
+    --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
+    --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
+    --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
+    --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
+    --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
+    --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
+    --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
+    --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
+    --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
+    # via -r build/test-requirements.txt
 typing-extensions==4.12.0rc1 \
     --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
     --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe
diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt
index 9fa78c062..4e551eeda 100644
--- a/build/requirements_lock_3_13.txt
+++ b/build/requirements_lock_3_13.txt
@@ -634,6 +634,29 @@ sortedcontainers==2.4.0 \
     --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
     --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
     # via hypothesis
+tensorstore==0.1.72 \
+    --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
+    --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
+    --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
+    --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
+    --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
+    --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
+    --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
+    --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
+    --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
+    --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
+    --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
+    --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
+    --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
+    --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
+    --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
+    --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
+    --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
+    --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
+    --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
+    --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
+    --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
+    # via -r build/test-requirements.txt
 typing-extensions==4.12.2 \
     --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
     --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt
index dfefaf042..bb10b1adc 100644
--- a/build/requirements_lock_3_13_ft.txt
+++ b/build/requirements_lock_3_13_ft.txt
@@ -588,6 +588,29 @@ sortedcontainers==2.4.0 \
     --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
     --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
     # via hypothesis
+tensorstore==0.1.72 \
+    --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
+    --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
+    --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
+    --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
+    --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
+    --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
+    --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
+    --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
+    --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
+    --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
+    --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
+    --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
+    --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
+    --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
+    --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
+    --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
+    --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
+    --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
+    --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
+    --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
+    --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
+    # via -r build/test-requirements.txt
 typing-extensions==4.12.2 \
     --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
     --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
diff --git a/build/test-requirements.txt b/build/test-requirements.txt
index 19d713532..918bd4318 100644
--- a/build/test-requirements.txt
+++ b/build/test-requirements.txt
@@ -20,3 +20,4 @@ matplotlib~=3.8.4; python_version=="3.10"
 matplotlib; python_version>="3.11"
 opt-einsum
 auditwheel
+tensorstore==0.1.72
diff --git a/build_wheel.py b/build_wheel.py
new file mode 100644
index 000000000..ec93b4155
--- /dev/null
+++ b/build_wheel.py
@@ -0,0 +1,100 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Script that builds a JAX wheel, intended to be run via bazel run as part
+# of the JAX build process.
+
+import argparse
+import os
+import pathlib
+import shutil
+import tempfile
+
+from jaxlib.tools import build_utils
+
+parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
+parser.add_argument(
+    "--sources_path",
+    default=None,
+    help=(
+        "Path in which the wheel's sources should be prepared. Optional. If "
+        "omitted, a temporary directory will be used."
+    ),
+)
+parser.add_argument(
+    "--output_path",
+    default=None,
+    required=True,
+    help="Path to which the output wheel should be written. Required.",
+)
+parser.add_argument(
+    "--jaxlib_git_hash",
+    default="",
+    required=True,
+    help="Git hash. Empty if unknown. Optional.",
+)
+parser.add_argument(
+    "--srcs", help="source files for the wheel", action="append"
+)
+args = parser.parse_args()
+
+
+def copy_file(
+    src_file: str,
+    dst_dir: str,
+) -> None:
+  """Copy a file to the destination directory.
+
+  Args:
+    src_file: file to be copied
+    dst_dir: destination directory
+  """
+
+  dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
+  os.makedirs(dest_dir_path, exist_ok=True)
+  shutil.copy(src_file, dest_dir_path)
+  os.chmod(os.path.join(dst_dir, src_file), 0o644)
+
+
+def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
+  """Filter the sources and copy them to the destination directory.
+
+  Args:
+    deps: a list of paths to files.
+    srcs_dir: target directory where files are copied to.
+  """
+
+  for file in deps:
+    if not (file.startswith("bazel-out") or file.startswith("external")):
+      copy_file(file, srcs_dir)
+
+
+tmpdir = None
+sources_path = args.sources_path
+if sources_path is None:
+  tmpdir = tempfile.TemporaryDirectory(prefix="jax")
+  sources_path = tmpdir.name
+
+try:
+  os.makedirs(args.output_path, exist_ok=True)
+  prepare_srcs(args.srcs, pathlib.Path(sources_path))
+  build_utils.build_wheel(
+      sources_path,
+      args.output_path,
+      package_name="jax",
+      git_hash=args.jaxlib_git_hash,
+  )
+finally:
+  if tmpdir:
+    tmpdir.cleanup()
diff --git a/jax/BUILD b/jax/BUILD
index 173b519bc..7736b6de5 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -61,6 +61,7 @@ config_setting(
 exports_files([
     "LICENSE",
     "version.py",
+    "py.typed",
 ])
 
 exports_files(
@@ -117,7 +118,6 @@ py_library(
     # JAX does provide some public test utilities (see jax/test_util.py);
     # these are available in jax.test_util via the standard :jax target.
     name = "test_util",
-    testonly = 1,
     srcs = [
         "_src/test_util.py",
         "_src/test_warning_util.py",
@@ -134,8 +134,8 @@ py_library(
 # TODO(necula): break the internal_test_util into smaller build targets.
 py_library(
     name = "internal_test_util",
-    testonly = 1,
     srcs = [
+        "_src/internal_test_util/__init__.py",
         "_src/internal_test_util/deprecation_module.py",
         "_src/internal_test_util/lax_test_util.py",
     ] + glob(
@@ -151,7 +151,6 @@ py_library(
 
 py_library(
     name = "internal_test_harnesses",
-    testonly = 1,
     srcs = ["_src/internal_test_util/test_harnesses.py"],
     visibility = [":internal"] + jax_internal_test_harnesses_visibility,
     deps = [
@@ -165,7 +164,6 @@ py_library(
 
 py_library(
     name = "internal_export_back_compat_test_util",
-    testonly = 1,
     srcs = ["_src/internal_test_util/export_back_compat_test_util.py"],
     visibility = [
         ":internal",
@@ -681,7 +679,6 @@ pytype_strict_library(
 
 pytype_strict_library(
     name = "pallas_experimental_gpu_ops",
-    testonly = True,
     srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"],
     visibility = [
         ":mosaic_gpu_users",
@@ -1121,7 +1118,6 @@ pytype_library(
 
 pytype_library(
     name = "sparse_test_util",
-    testonly = 1,
     srcs = [
         "experimental/sparse/test_util.py",
     ],
@@ -1183,6 +1179,7 @@ pytype_library(
 pytype_library(
     name = "compilation_cache",
     srcs = [
+        "experimental/compilation_cache/__init__.py",
         "experimental/compilation_cache/compilation_cache.py",
     ],
     visibility = ["//visibility:public"],
diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD
index d239fba98..3ef324372 100644
--- a/jax/_src/pallas/mosaic/BUILD
+++ b/jax/_src/pallas/mosaic/BUILD
@@ -26,7 +26,10 @@ package(
 
 py_library(
     name = "core",
-    srcs = ["core.py"],
+    srcs = [
+        "__init__.py",
+        "core.py",
+    ],
     deps = [
         "//jax",
         "//jax/_src/pallas",
diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD
index 84fae3913..cde2aadd6 100644
--- a/jax/_src/pallas/triton/BUILD
+++ b/jax/_src/pallas/triton/BUILD
@@ -29,7 +29,10 @@ package(
 
 pytype_strict_library(
     name = "core",
-    srcs = ["core.py"],
+    srcs = [
+        "__init__.py",
+        "core.py",
+    ],
     deps = ["//jax/_src/pallas"],
 )
 
diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD
new file mode 100644
index 000000000..24eba04aa
--- /dev/null
+++ b/jax/experimental/array_serialization/BUILD
@@ -0,0 +1,41 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+    "//jaxlib:jax.bzl",
+    "py_deps",
+    "pytype_library",
+)
+
+licenses(["notice"])
+
+package(
+    default_applicable_licenses = [],
+)
+
+pytype_library(
+    name = "serialization",
+    srcs = [
+        "__init__.py",
+        "serialization.py",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//jax",
+    ] + py_deps([
+        "numpy",
+        "tensorstore",
+        "absl/logging",
+    ]),
+)
diff --git a/jax/extend/mlir/BUILD b/jax/extend/mlir/BUILD
index 8b8304282..17d5aab3a 100644
--- a/jax/extend/mlir/BUILD
+++ b/jax/extend/mlir/BUILD
@@ -23,6 +23,15 @@ package(
     default_visibility = ["//jax:jax_extend_users"],
 )
 
+pytype_strict_library(
+    name = "mlir",
+    srcs = ["__init__.py"],
+    deps = [
+        ":ir",
+        ":pass_manager",
+    ],
+)
+
 pytype_strict_library(
     name = "ir",
     srcs = ["ir.py"],
diff --git a/jax/extend/mlir/dialects/BUILD b/jax/extend/mlir/dialects/BUILD
index 7bd9e95b0..75275b45b 100644
--- a/jax/extend/mlir/dialects/BUILD
+++ b/jax/extend/mlir/dialects/BUILD
@@ -23,6 +23,24 @@ package(
     default_visibility = ["//jax:jax_extend_users"],
 )
 
+pytype_strict_library(
+    name = "dialects",
+    srcs = ["__init__.py"],
+    deps = [
+        ":arithmetic_dialect",
+        ":builtin_dialect",
+        ":chlo_dialect",
+        ":func_dialect",
+        ":math_dialect",
+        ":memref_dialect",
+        ":scf_dialect",
+        ":sdy_dialect",
+        ":sparse_tensor_dialect",
+        ":stablehlo_dialect",
+        ":vector_dialect",
+    ],
+)
+
 pytype_strict_library(
     name = "arithmetic_dialect",
     srcs = ["arith.py"],
diff --git a/jax/tools/BUILD b/jax/tools/BUILD
index 3e0a95029..e9a870351 100644
--- a/jax/tools/BUILD
+++ b/jax/tools/BUILD
@@ -25,9 +25,26 @@ package(
     default_visibility = ["//visibility:public"],
 )
 
+py_library(
+    name = "pgo_nsys_converter",
+    srcs = [
+        "pgo_nsys_converter.py",
+    ],
+)
+
+py_library(
+    name = "colab_tpu",
+    srcs = [
+        "colab_tpu.py",
+    ],
+)
+
 py_library(
     name = "jax_to_ir",
-    srcs = ["jax_to_ir.py"],
+    srcs = [
+        "__init__.py",
+        "jax_to_ir.py",
+    ],
     tags = [
         "ignore_for_dep=third_party.py.jax.experimental.jax2tf",
         "ignore_for_dep=third_party.py.tensorflow",
diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl
index e42f1e311..f7f39598f 100644
--- a/jaxlib/jax.bzl
+++ b/jaxlib/jax.bzl
@@ -87,6 +87,7 @@ _py_deps = {
     "ml_dtypes": ["@pypi_ml_dtypes//:pkg"],
     "numpy": ["@pypi_numpy//:pkg"],
     "scipy": ["@pypi_scipy//:pkg"],
+    "tensorstore": ["@pypi_tensorstore//:pkg"],
     "tensorflow_core": [],
     "torch": [],
     "zstandard": get_zstandard(),
@@ -125,12 +126,15 @@ def pytype_library(name, pytype_srcs = None, **kwargs):
     _ = pytype_srcs  # @unused
     native.py_library(name = name, **kwargs)
 
-def pytype_strict_library(name, pytype_srcs = None, **kwargs):
-    _ = pytype_srcs  # @unused
-    native.py_library(name = name, **kwargs)
+def pytype_strict_library(name, pytype_srcs = [], **kwargs):
+    data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
+    new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
+    native.py_library(name = name, data = data, **new_kwargs)
 
 def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
-    lib_rule(name = name, **kwargs)
+    data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
+    new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
+    lib_rule(name = name, data = data, **new_kwargs)
 
 def py_extension(name, srcs, copts, deps, linkopts = []):
     nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
@@ -321,8 +325,8 @@ def jax_generate_backend_suites(backends = []):
         tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
     )
 
-def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
-    if no_abi:
+def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version):
+    if no_abi or platform_independent:
         wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
     else:
         wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
@@ -332,7 +336,9 @@ def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_ve
         python_version = python_version,
         major_python_version = python_version[0],
         wheel_version = wheel_version,
-        wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
+        wheel_platform_tag = "any" if platform_independent else "_".join(
+            PLATFORM_TAGS_DICT[platform_name, cpu_name],
+        ),
     )
 
 def _jax_wheel_impl(ctx):
@@ -360,10 +366,13 @@ def _jax_wheel_impl(ctx):
         env["JAX_RELEASE"] = "1"
 
     cpu = ctx.attr.cpu
+    no_abi = ctx.attr.no_abi
+    platform_independent = ctx.attr.platform_independent
     platform_name = ctx.attr.platform_name
     wheel_name = _get_full_wheel_name(
         package_name = ctx.attr.wheel_name,
-        no_abi = ctx.attr.no_abi,
+        no_abi = no_abi,
+        platform_independent = platform_independent,
         platform_name = platform_name,
         cpu_name = cpu,
         wheel_version = full_wheel_version,
@@ -373,7 +382,8 @@ def _jax_wheel_impl(ctx):
     wheel_dir = output_file.path[:output_file.path.rfind("/")]
 
     args.add("--output_path", wheel_dir)  # required argument
-    args.add("--cpu", cpu)  # required argument
+    if not platform_independent:
+        args.add("--cpu", cpu)
     args.add("--jaxlib_git_hash", git_hash)  # required argument
 
     if ctx.attr.enable_cuda:
@@ -389,14 +399,21 @@ def _jax_wheel_impl(ctx):
     if ctx.attr.skip_gpu_kernels:
         args.add("--skip_gpu_kernels")
 
+    srcs = []
+    for src in ctx.attr.source_files:
+        for f in src.files.to_list():
+            srcs.append(f)
+            args.add("--srcs=%s" % (f.path))
+
     args.set_param_file_format("flag_per_line")
     args.use_param_file("@%s", use_always = False)
     ctx.actions.run(
         arguments = [args],
-        inputs = [],
+        inputs = srcs,
         outputs = [output_file],
         executable = executable,
         env = env,
+        mnemonic = "BuildJaxWheel",
     )
 
     return [DefaultInfo(files = depset(direct = [output_file]))]
@@ -411,9 +428,11 @@ _jax_wheel = rule(
         ),
         "wheel_name": attr.string(mandatory = True),
         "no_abi": attr.bool(default = False),
+        "platform_independent": attr.bool(default = False),
         "cpu": attr.string(mandatory = True),
         "platform_name": attr.string(mandatory = True),
         "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
+        "source_files": attr.label_list(allow_files = True),
         "output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
         "enable_cuda": attr.bool(default = False),
         # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
@@ -427,7 +446,15 @@ _jax_wheel = rule(
     executable = False,
 )
 
-def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
+def jax_wheel(
+        name,
+        wheel_binary,
+        wheel_name,
+        no_abi = False,
+        platform_independent = False,
+        enable_cuda = False,
+        platform_version = "",
+        source_files = []):
     """Create jax artifact wheels.
 
     Common artifact attributes are grouped within a single macro.
@@ -437,8 +464,10 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
       wheel_binary: the binary to use to build the wheel
       wheel_name: the name of the wheel
       no_abi: whether to build a wheel without ABI
+      platform_independent: whether to build a wheel without platform tag
       enable_cuda: whether to build a cuda wheel
       platform_version: the cuda version to use for the wheel
+      source_files: the source files to include in the wheel
 
     Returns:
       A directory containing the wheel
@@ -448,6 +477,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
         wheel_binary = wheel_binary,
         wheel_name = wheel_name,
         no_abi = no_abi,
+        platform_independent = platform_independent,
         enable_cuda = enable_cuda,
         platform_version = platform_version,
         # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
@@ -465,6 +495,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
             "//jaxlib/tools:arm64": "aarch64",
             "@platforms//cpu:x86_64": "x86_64",
         }),
+        source_files = source_files,
     )
 
 jax_test_file_visibility = []
diff --git a/tests/BUILD b/tests/BUILD
index 8241abfff..ad714240e 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -1650,6 +1650,18 @@ jax_py_test(
     ],
 )
 
+jax_multiplatform_test(
+    name = "serialization_test",
+    srcs = ["serialization_test.py"],
+    enable_configs = [
+        "tpu_v3_2x2",
+    ],
+    deps = [
+        "//jax:experimental",
+        "//jax/experimental/array_serialization:serialization",
+    ],
+)
+
 exports_files(
     [
         "api_test.py",
diff --git a/jax/experimental/array_serialization/serialization_test.py b/tests/serialization_test.py
similarity index 100%
rename from jax/experimental/array_serialization/serialization_test.py
rename to tests/serialization_test.py