From eb912ad0d9430f5057c192afa0c3cf60a5b84926 Mon Sep 17 00:00:00 2001 From: jax authors <google-ml-automation@google.com> Date: Tue, 25 Feb 2025 09:28:35 -0800 Subject: [PATCH] Create `jax` wheel build target. This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126) Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file). You can still build the `jax` wheel with `python3 -m build` command. Bazel `jax` wheel target: `//:jax_wheel` Environment variables combinations for creating wheels with different versions: * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot` * release: `--repo_env=ML_WHEEL_TYPE=release` * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1` * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)` PiperOrigin-RevId: 730916743 --- BUILD | 86 +++++++++++++++ build/requirements_lock_3_10.txt | 23 ++++ build/requirements_lock_3_11.txt | 23 ++++ build/requirements_lock_3_12.txt | 23 ++++ build/requirements_lock_3_13.txt | 23 ++++ build/requirements_lock_3_13_ft.txt | 23 ++++ build/test-requirements.txt | 1 + build_wheel.py | 100 ++++++++++++++++++ jax/BUILD | 9 +- jax/_src/pallas/mosaic/BUILD | 5 +- jax/_src/pallas/triton/BUILD | 5 +- jax/experimental/array_serialization/BUILD | 41 +++++++ jax/extend/mlir/BUILD | 9 ++ jax/extend/mlir/dialects/BUILD | 18 ++++ jax/tools/BUILD | 19 +++- jaxlib/jax.bzl | 53 ++++++++-- tests/BUILD | 12 +++ .../serialization_test.py | 0 18 files changed, 453 insertions(+), 20 deletions(-) create mode 100644 BUILD create mode 100644 build_wheel.py create mode 100644 jax/experimental/array_serialization/BUILD rename {jax/experimental/array_serialization => tests}/serialization_test.py (100%) diff --git a/BUILD b/BUILD new file mode 100644 index 000000000..ec6c87166 --- /dev/null +++ b/BUILD @@ -0,0 +1,86 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps") +load( + "//jaxlib:jax.bzl", + "jax_wheel", +) + +collect_data_files( + name = "transitive_py_data", + deps = ["//jax"], +) + +transitive_py_deps( + name = "transitive_py_deps", + deps = [ + "//jax", + "//jax:compilation_cache", + "//jax:experimental", + "//jax:experimental_colocated_python", + "//jax:experimental_sparse", + "//jax:internal_export_back_compat_test_util", + "//jax:internal_test_harnesses", + "//jax:internal_test_util", + "//jax:lax_reference", + "//jax:pallas_experimental_gpu_ops", + "//jax:pallas_gpu_ops", + "//jax:pallas_mosaic_gpu", + "//jax:pallas_tpu_ops", + "//jax:pallas_triton", + "//jax:source_mapper", + "//jax:sparse_test_util", + "//jax:test_util", + "//jax/_src/lib", + "//jax/_src/pallas/mosaic_gpu", + "//jax/experimental/array_serialization:serialization", + "//jax/experimental/jax2tf", + "//jax/extend", + "//jax/extend:ifrt_programs", + "//jax/extend/mlir", + "//jax/extend/mlir/dialects", + "//jax/tools:colab_tpu", + "//jax/tools:jax_to_ir", + "//jax/tools:pgo_nsys_converter", + ], +) + +py_binary( + name = "build_wheel", + srcs = ["build_wheel.py"], + deps = [ + "//jaxlib/tools:build_utils", + "@pypi_build//:pkg", + "@pypi_setuptools//:pkg", + "@pypi_wheel//:pkg", + ], +) + +jax_wheel( + name = "jax_wheel", + platform_independent = True, + source_files = [ + ":transitive_py_data", + ":transitive_py_deps", + "//jax:py.typed", + "AUTHORS", + "LICENSE", + "README.md", + "pyproject.toml", + "setup.py", + ], + wheel_binary = ":build_wheel", + wheel_name = "jax", +) diff --git a/build/requirements_lock_3_10.txt b/build/requirements_lock_3_10.txt index ccffa247f..c346daea3 100644 --- a/build/requirements_lock_3_10.txt +++ b/build/requirements_lock_3_10.txt @@ -566,6 +566,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.72 \ + --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \ + --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \ + --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \ + --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \ + --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \ + --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \ + --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \ + --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \ + --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \ + --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \ + --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \ + --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \ + --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \ + --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \ + --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \ + --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \ + --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \ + --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \ + --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \ + --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \ + --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6 + # via -r build/test-requirements.txt tomli==2.0.1 \ --hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \ --hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f diff --git a/build/requirements_lock_3_11.txt b/build/requirements_lock_3_11.txt index 7f3ee61ff..faa641940 100644 --- a/build/requirements_lock_3_11.txt +++ b/build/requirements_lock_3_11.txt @@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.72 \ + --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \ + --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \ + --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \ + --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \ + --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \ + --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \ + --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \ + --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \ + --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \ + --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \ + --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \ + --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \ + --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \ + --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \ + --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \ + --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \ + --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \ + --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \ + --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \ + --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \ + --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6 + # via -r build/test-requirements.txt typing-extensions==4.12.0rc1 \ --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe diff --git a/build/requirements_lock_3_12.txt b/build/requirements_lock_3_12.txt index bf22c3623..d1f8d3ad5 100644 --- a/build/requirements_lock_3_12.txt +++ b/build/requirements_lock_3_12.txt @@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.72 \ + --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \ + --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \ + --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \ + --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \ + --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \ + --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \ + --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \ + --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \ + --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \ + --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \ + --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \ + --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \ + --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \ + --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \ + --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \ + --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \ + --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \ + --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \ + --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \ + --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \ + --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6 + # via -r build/test-requirements.txt typing-extensions==4.12.0rc1 \ --hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \ --hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 9fa78c062..4e551eeda 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -634,6 +634,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.72 \ + --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \ + --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \ + --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \ + --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \ + --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \ + --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \ + --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \ + --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \ + --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \ + --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \ + --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \ + --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \ + --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \ + --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \ + --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \ + --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \ + --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \ + --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \ + --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \ + --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \ + --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6 + # via -r build/test-requirements.txt typing-extensions==4.12.2 \ --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt index dfefaf042..bb10b1adc 100644 --- a/build/requirements_lock_3_13_ft.txt +++ b/build/requirements_lock_3_13_ft.txt @@ -588,6 +588,29 @@ sortedcontainers==2.4.0 \ --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis +tensorstore==0.1.72 \ + --hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \ + --hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \ + --hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \ + --hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \ + --hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \ + --hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \ + --hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \ + --hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \ + --hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \ + --hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \ + --hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \ + --hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \ + --hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \ + --hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \ + --hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \ + --hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \ + --hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \ + --hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \ + --hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \ + --hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \ + --hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6 + # via -r build/test-requirements.txt typing-extensions==4.12.2 \ --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 19d713532..918bd4318 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -20,3 +20,4 @@ matplotlib~=3.8.4; python_version=="3.10" matplotlib; python_version>="3.11" opt-einsum auditwheel +tensorstore==0.1.72 diff --git a/build_wheel.py b/build_wheel.py new file mode 100644 index 000000000..ec93b4155 --- /dev/null +++ b/build_wheel.py @@ -0,0 +1,100 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Script that builds a JAX wheel, intended to be run via bazel run as part +# of the JAX build process. + +import argparse +import os +import pathlib +import shutil +import tempfile + +from jaxlib.tools import build_utils + +parser = argparse.ArgumentParser(fromfile_prefix_chars="@") +parser.add_argument( + "--sources_path", + default=None, + help=( + "Path in which the wheel's sources should be prepared. Optional. If " + "omitted, a temporary directory will be used." + ), +) +parser.add_argument( + "--output_path", + default=None, + required=True, + help="Path to which the output wheel should be written. Required.", +) +parser.add_argument( + "--jaxlib_git_hash", + default="", + required=True, + help="Git hash. Empty if unknown. Optional.", +) +parser.add_argument( + "--srcs", help="source files for the wheel", action="append" +) +args = parser.parse_args() + + +def copy_file( + src_file: str, + dst_dir: str, +) -> None: + """Copy a file to the destination directory. + + Args: + src_file: file to be copied + dst_dir: destination directory + """ + + dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file)) + os.makedirs(dest_dir_path, exist_ok=True) + shutil.copy(src_file, dest_dir_path) + os.chmod(os.path.join(dst_dir, src_file), 0o644) + + +def prepare_srcs(deps: list[str], srcs_dir: str) -> None: + """Filter the sources and copy them to the destination directory. + + Args: + deps: a list of paths to files. + srcs_dir: target directory where files are copied to. + """ + + for file in deps: + if not (file.startswith("bazel-out") or file.startswith("external")): + copy_file(file, srcs_dir) + + +tmpdir = None +sources_path = args.sources_path +if sources_path is None: + tmpdir = tempfile.TemporaryDirectory(prefix="jax") + sources_path = tmpdir.name + +try: + os.makedirs(args.output_path, exist_ok=True) + prepare_srcs(args.srcs, pathlib.Path(sources_path)) + build_utils.build_wheel( + sources_path, + args.output_path, + package_name="jax", + git_hash=args.jaxlib_git_hash, + ) +finally: + if tmpdir: + tmpdir.cleanup() diff --git a/jax/BUILD b/jax/BUILD index 173b519bc..7736b6de5 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -61,6 +61,7 @@ config_setting( exports_files([ "LICENSE", "version.py", + "py.typed", ]) exports_files( @@ -117,7 +118,6 @@ py_library( # JAX does provide some public test utilities (see jax/test_util.py); # these are available in jax.test_util via the standard :jax target. name = "test_util", - testonly = 1, srcs = [ "_src/test_util.py", "_src/test_warning_util.py", @@ -134,8 +134,8 @@ py_library( # TODO(necula): break the internal_test_util into smaller build targets. py_library( name = "internal_test_util", - testonly = 1, srcs = [ + "_src/internal_test_util/__init__.py", "_src/internal_test_util/deprecation_module.py", "_src/internal_test_util/lax_test_util.py", ] + glob( @@ -151,7 +151,6 @@ py_library( py_library( name = "internal_test_harnesses", - testonly = 1, srcs = ["_src/internal_test_util/test_harnesses.py"], visibility = [":internal"] + jax_internal_test_harnesses_visibility, deps = [ @@ -165,7 +164,6 @@ py_library( py_library( name = "internal_export_back_compat_test_util", - testonly = 1, srcs = ["_src/internal_test_util/export_back_compat_test_util.py"], visibility = [ ":internal", @@ -681,7 +679,6 @@ pytype_strict_library( pytype_strict_library( name = "pallas_experimental_gpu_ops", - testonly = True, srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"], visibility = [ ":mosaic_gpu_users", @@ -1121,7 +1118,6 @@ pytype_library( pytype_library( name = "sparse_test_util", - testonly = 1, srcs = [ "experimental/sparse/test_util.py", ], @@ -1183,6 +1179,7 @@ pytype_library( pytype_library( name = "compilation_cache", srcs = [ + "experimental/compilation_cache/__init__.py", "experimental/compilation_cache/compilation_cache.py", ], visibility = ["//visibility:public"], diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index d239fba98..3ef324372 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -26,7 +26,10 @@ package( py_library( name = "core", - srcs = ["core.py"], + srcs = [ + "__init__.py", + "core.py", + ], deps = [ "//jax", "//jax/_src/pallas", diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 84fae3913..cde2aadd6 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -29,7 +29,10 @@ package( pytype_strict_library( name = "core", - srcs = ["core.py"], + srcs = [ + "__init__.py", + "core.py", + ], deps = ["//jax/_src/pallas"], ) diff --git a/jax/experimental/array_serialization/BUILD b/jax/experimental/array_serialization/BUILD new file mode 100644 index 000000000..24eba04aa --- /dev/null +++ b/jax/experimental/array_serialization/BUILD @@ -0,0 +1,41 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//jaxlib:jax.bzl", + "py_deps", + "pytype_library", +) + +licenses(["notice"]) + +package( + default_applicable_licenses = [], +) + +pytype_library( + name = "serialization", + srcs = [ + "__init__.py", + "serialization.py", + ], + visibility = ["//visibility:public"], + deps = [ + "//jax", + ] + py_deps([ + "numpy", + "tensorstore", + "absl/logging", + ]), +) diff --git a/jax/extend/mlir/BUILD b/jax/extend/mlir/BUILD index 8b8304282..17d5aab3a 100644 --- a/jax/extend/mlir/BUILD +++ b/jax/extend/mlir/BUILD @@ -23,6 +23,15 @@ package( default_visibility = ["//jax:jax_extend_users"], ) +pytype_strict_library( + name = "mlir", + srcs = ["__init__.py"], + deps = [ + ":ir", + ":pass_manager", + ], +) + pytype_strict_library( name = "ir", srcs = ["ir.py"], diff --git a/jax/extend/mlir/dialects/BUILD b/jax/extend/mlir/dialects/BUILD index 7bd9e95b0..75275b45b 100644 --- a/jax/extend/mlir/dialects/BUILD +++ b/jax/extend/mlir/dialects/BUILD @@ -23,6 +23,24 @@ package( default_visibility = ["//jax:jax_extend_users"], ) +pytype_strict_library( + name = "dialects", + srcs = ["__init__.py"], + deps = [ + ":arithmetic_dialect", + ":builtin_dialect", + ":chlo_dialect", + ":func_dialect", + ":math_dialect", + ":memref_dialect", + ":scf_dialect", + ":sdy_dialect", + ":sparse_tensor_dialect", + ":stablehlo_dialect", + ":vector_dialect", + ], +) + pytype_strict_library( name = "arithmetic_dialect", srcs = ["arith.py"], diff --git a/jax/tools/BUILD b/jax/tools/BUILD index 3e0a95029..e9a870351 100644 --- a/jax/tools/BUILD +++ b/jax/tools/BUILD @@ -25,9 +25,26 @@ package( default_visibility = ["//visibility:public"], ) +py_library( + name = "pgo_nsys_converter", + srcs = [ + "pgo_nsys_converter.py", + ], +) + +py_library( + name = "colab_tpu", + srcs = [ + "colab_tpu.py", + ], +) + py_library( name = "jax_to_ir", - srcs = ["jax_to_ir.py"], + srcs = [ + "__init__.py", + "jax_to_ir.py", + ], tags = [ "ignore_for_dep=third_party.py.jax.experimental.jax2tf", "ignore_for_dep=third_party.py.tensorflow", diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index e42f1e311..f7f39598f 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -87,6 +87,7 @@ _py_deps = { "ml_dtypes": ["@pypi_ml_dtypes//:pkg"], "numpy": ["@pypi_numpy//:pkg"], "scipy": ["@pypi_scipy//:pkg"], + "tensorstore": ["@pypi_tensorstore//:pkg"], "tensorflow_core": [], "torch": [], "zstandard": get_zstandard(), @@ -125,12 +126,15 @@ def pytype_library(name, pytype_srcs = None, **kwargs): _ = pytype_srcs # @unused native.py_library(name = name, **kwargs) -def pytype_strict_library(name, pytype_srcs = None, **kwargs): - _ = pytype_srcs # @unused - native.py_library(name = name, **kwargs) +def pytype_strict_library(name, pytype_srcs = [], **kwargs): + data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) + new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} + native.py_library(name = name, data = data, **new_kwargs) def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs): - lib_rule(name = name, **kwargs) + data = pytype_srcs + (kwargs["data"] if "data" in kwargs else []) + new_kwargs = {k: v for k, v in kwargs.items() if k != "data"} + lib_rule(name = name, data = data, **new_kwargs) def py_extension(name, srcs, copts, deps, linkopts = []): nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name) @@ -321,8 +325,8 @@ def jax_generate_backend_suites(backends = []): tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"], ) -def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version): - if no_abi: +def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version): + if no_abi or platform_independent: wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl" else: wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl" @@ -332,7 +336,9 @@ def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_ve python_version = python_version, major_python_version = python_version[0], wheel_version = wheel_version, - wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]), + wheel_platform_tag = "any" if platform_independent else "_".join( + PLATFORM_TAGS_DICT[platform_name, cpu_name], + ), ) def _jax_wheel_impl(ctx): @@ -360,10 +366,13 @@ def _jax_wheel_impl(ctx): env["JAX_RELEASE"] = "1" cpu = ctx.attr.cpu + no_abi = ctx.attr.no_abi + platform_independent = ctx.attr.platform_independent platform_name = ctx.attr.platform_name wheel_name = _get_full_wheel_name( package_name = ctx.attr.wheel_name, - no_abi = ctx.attr.no_abi, + no_abi = no_abi, + platform_independent = platform_independent, platform_name = platform_name, cpu_name = cpu, wheel_version = full_wheel_version, @@ -373,7 +382,8 @@ def _jax_wheel_impl(ctx): wheel_dir = output_file.path[:output_file.path.rfind("/")] args.add("--output_path", wheel_dir) # required argument - args.add("--cpu", cpu) # required argument + if not platform_independent: + args.add("--cpu", cpu) args.add("--jaxlib_git_hash", git_hash) # required argument if ctx.attr.enable_cuda: @@ -389,14 +399,21 @@ def _jax_wheel_impl(ctx): if ctx.attr.skip_gpu_kernels: args.add("--skip_gpu_kernels") + srcs = [] + for src in ctx.attr.source_files: + for f in src.files.to_list(): + srcs.append(f) + args.add("--srcs=%s" % (f.path)) + args.set_param_file_format("flag_per_line") args.use_param_file("@%s", use_always = False) ctx.actions.run( arguments = [args], - inputs = [], + inputs = srcs, outputs = [output_file], executable = executable, env = env, + mnemonic = "BuildJaxWheel", ) return [DefaultInfo(files = depset(direct = [output_file]))] @@ -411,9 +428,11 @@ _jax_wheel = rule( ), "wheel_name": attr.string(mandatory = True), "no_abi": attr.bool(default = False), + "platform_independent": attr.bool(default = False), "cpu": attr.string(mandatory = True), "platform_name": attr.string(mandatory = True), "git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")), + "source_files": attr.label_list(allow_files = True), "output_path": attr.label(default = Label("//jaxlib/tools:output_path")), "enable_cuda": attr.bool(default = False), # A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string. @@ -427,7 +446,15 @@ _jax_wheel = rule( executable = False, ) -def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""): +def jax_wheel( + name, + wheel_binary, + wheel_name, + no_abi = False, + platform_independent = False, + enable_cuda = False, + platform_version = "", + source_files = []): """Create jax artifact wheels. Common artifact attributes are grouped within a single macro. @@ -437,8 +464,10 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals wheel_binary: the binary to use to build the wheel wheel_name: the name of the wheel no_abi: whether to build a wheel without ABI + platform_independent: whether to build a wheel without platform tag enable_cuda: whether to build a cuda wheel platform_version: the cuda version to use for the wheel + source_files: the source files to include in the wheel Returns: A directory containing the wheel @@ -448,6 +477,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals wheel_binary = wheel_binary, wheel_name = wheel_name, no_abi = no_abi, + platform_independent = platform_independent, enable_cuda = enable_cuda, platform_version = platform_version, # git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)` @@ -465,6 +495,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals "//jaxlib/tools:arm64": "aarch64", "@platforms//cpu:x86_64": "x86_64", }), + source_files = source_files, ) jax_test_file_visibility = [] diff --git a/tests/BUILD b/tests/BUILD index 8241abfff..ad714240e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1650,6 +1650,18 @@ jax_py_test( ], ) +jax_multiplatform_test( + name = "serialization_test", + srcs = ["serialization_test.py"], + enable_configs = [ + "tpu_v3_2x2", + ], + deps = [ + "//jax:experimental", + "//jax/experimental/array_serialization:serialization", + ], +) + exports_files( [ "api_test.py", diff --git a/jax/experimental/array_serialization/serialization_test.py b/tests/serialization_test.py similarity index 100% rename from jax/experimental/array_serialization/serialization_test.py rename to tests/serialization_test.py