mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Create jax
wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126) Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file). You can still build the `jax` wheel with `python3 -m build` command. Bazel `jax` wheel target: `//:jax_wheel` Environment variables combinations for creating wheels with different versions: * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot` * release: `--repo_env=ML_WHEEL_TYPE=release` * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1` * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)` PiperOrigin-RevId: 730916743
This commit is contained in:
parent
7a162f2abc
commit
eb912ad0d9
86
BUILD
Normal file
86
BUILD
Normal file
@ -0,0 +1,86 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"jax_wheel",
|
||||
)
|
||||
|
||||
collect_data_files(
|
||||
name = "transitive_py_data",
|
||||
deps = ["//jax"],
|
||||
)
|
||||
|
||||
transitive_py_deps(
|
||||
name = "transitive_py_deps",
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:compilation_cache",
|
||||
"//jax:experimental",
|
||||
"//jax:experimental_colocated_python",
|
||||
"//jax:experimental_sparse",
|
||||
"//jax:internal_export_back_compat_test_util",
|
||||
"//jax:internal_test_harnesses",
|
||||
"//jax:internal_test_util",
|
||||
"//jax:lax_reference",
|
||||
"//jax:pallas_experimental_gpu_ops",
|
||||
"//jax:pallas_gpu_ops",
|
||||
"//jax:pallas_mosaic_gpu",
|
||||
"//jax:pallas_tpu_ops",
|
||||
"//jax:pallas_triton",
|
||||
"//jax:source_mapper",
|
||||
"//jax:sparse_test_util",
|
||||
"//jax:test_util",
|
||||
"//jax/_src/lib",
|
||||
"//jax/_src/pallas/mosaic_gpu",
|
||||
"//jax/experimental/array_serialization:serialization",
|
||||
"//jax/experimental/jax2tf",
|
||||
"//jax/extend",
|
||||
"//jax/extend:ifrt_programs",
|
||||
"//jax/extend/mlir",
|
||||
"//jax/extend/mlir/dialects",
|
||||
"//jax/tools:colab_tpu",
|
||||
"//jax/tools:jax_to_ir",
|
||||
"//jax/tools:pgo_nsys_converter",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "build_wheel",
|
||||
srcs = ["build_wheel.py"],
|
||||
deps = [
|
||||
"//jaxlib/tools:build_utils",
|
||||
"@pypi_build//:pkg",
|
||||
"@pypi_setuptools//:pkg",
|
||||
"@pypi_wheel//:pkg",
|
||||
],
|
||||
)
|
||||
|
||||
jax_wheel(
|
||||
name = "jax_wheel",
|
||||
platform_independent = True,
|
||||
source_files = [
|
||||
":transitive_py_data",
|
||||
":transitive_py_deps",
|
||||
"//jax:py.typed",
|
||||
"AUTHORS",
|
||||
"LICENSE",
|
||||
"README.md",
|
||||
"pyproject.toml",
|
||||
"setup.py",
|
||||
],
|
||||
wheel_binary = ":build_wheel",
|
||||
wheel_name = "jax",
|
||||
)
|
@ -566,6 +566,29 @@ sortedcontainers==2.4.0 \
|
||||
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
|
||||
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
|
||||
# via hypothesis
|
||||
tensorstore==0.1.72 \
|
||||
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
|
||||
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
|
||||
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
|
||||
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
|
||||
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
|
||||
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
|
||||
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
|
||||
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
|
||||
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
|
||||
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
|
||||
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
|
||||
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
|
||||
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
|
||||
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
|
||||
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
|
||||
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
|
||||
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
|
||||
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
|
||||
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
|
||||
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
|
||||
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
|
||||
# via -r build/test-requirements.txt
|
||||
tomli==2.0.1 \
|
||||
--hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \
|
||||
--hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f
|
||||
|
@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
|
||||
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
|
||||
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
|
||||
# via hypothesis
|
||||
tensorstore==0.1.72 \
|
||||
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
|
||||
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
|
||||
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
|
||||
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
|
||||
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
|
||||
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
|
||||
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
|
||||
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
|
||||
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
|
||||
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
|
||||
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
|
||||
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
|
||||
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
|
||||
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
|
||||
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
|
||||
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
|
||||
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
|
||||
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
|
||||
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
|
||||
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
|
||||
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
|
||||
# via -r build/test-requirements.txt
|
||||
typing-extensions==4.12.0rc1 \
|
||||
--hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
|
||||
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe
|
||||
|
@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
|
||||
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
|
||||
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
|
||||
# via hypothesis
|
||||
tensorstore==0.1.72 \
|
||||
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
|
||||
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
|
||||
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
|
||||
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
|
||||
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
|
||||
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
|
||||
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
|
||||
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
|
||||
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
|
||||
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
|
||||
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
|
||||
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
|
||||
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
|
||||
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
|
||||
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
|
||||
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
|
||||
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
|
||||
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
|
||||
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
|
||||
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
|
||||
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
|
||||
# via -r build/test-requirements.txt
|
||||
typing-extensions==4.12.0rc1 \
|
||||
--hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
|
||||
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe
|
||||
|
@ -634,6 +634,29 @@ sortedcontainers==2.4.0 \
|
||||
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
|
||||
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
|
||||
# via hypothesis
|
||||
tensorstore==0.1.72 \
|
||||
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
|
||||
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
|
||||
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
|
||||
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
|
||||
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
|
||||
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
|
||||
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
|
||||
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
|
||||
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
|
||||
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
|
||||
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
|
||||
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
|
||||
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
|
||||
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
|
||||
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
|
||||
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
|
||||
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
|
||||
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
|
||||
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
|
||||
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
|
||||
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
|
||||
# via -r build/test-requirements.txt
|
||||
typing-extensions==4.12.2 \
|
||||
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
|
||||
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
|
||||
|
@ -588,6 +588,29 @@ sortedcontainers==2.4.0 \
|
||||
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
|
||||
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
|
||||
# via hypothesis
|
||||
tensorstore==0.1.72 \
|
||||
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
|
||||
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
|
||||
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
|
||||
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
|
||||
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
|
||||
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
|
||||
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
|
||||
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
|
||||
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
|
||||
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
|
||||
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
|
||||
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
|
||||
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
|
||||
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
|
||||
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
|
||||
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
|
||||
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
|
||||
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
|
||||
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
|
||||
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
|
||||
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
|
||||
# via -r build/test-requirements.txt
|
||||
typing-extensions==4.12.2 \
|
||||
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
|
||||
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
|
||||
|
@ -20,3 +20,4 @@ matplotlib~=3.8.4; python_version=="3.10"
|
||||
matplotlib; python_version>="3.11"
|
||||
opt-einsum
|
||||
auditwheel
|
||||
tensorstore==0.1.72
|
||||
|
100
build_wheel.py
Normal file
100
build_wheel.py
Normal file
@ -0,0 +1,100 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Script that builds a JAX wheel, intended to be run via bazel run as part
|
||||
# of the JAX build process.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
from jaxlib.tools import build_utils
|
||||
|
||||
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
|
||||
parser.add_argument(
|
||||
"--sources_path",
|
||||
default=None,
|
||||
help=(
|
||||
"Path in which the wheel's sources should be prepared. Optional. If "
|
||||
"omitted, a temporary directory will be used."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--jaxlib_git_hash",
|
||||
default="",
|
||||
required=True,
|
||||
help="Git hash. Empty if unknown. Optional.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--srcs", help="source files for the wheel", action="append"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def copy_file(
|
||||
src_file: str,
|
||||
dst_dir: str,
|
||||
) -> None:
|
||||
"""Copy a file to the destination directory.
|
||||
|
||||
Args:
|
||||
src_file: file to be copied
|
||||
dst_dir: destination directory
|
||||
"""
|
||||
|
||||
dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
|
||||
os.makedirs(dest_dir_path, exist_ok=True)
|
||||
shutil.copy(src_file, dest_dir_path)
|
||||
os.chmod(os.path.join(dst_dir, src_file), 0o644)
|
||||
|
||||
|
||||
def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
|
||||
"""Filter the sources and copy them to the destination directory.
|
||||
|
||||
Args:
|
||||
deps: a list of paths to files.
|
||||
srcs_dir: target directory where files are copied to.
|
||||
"""
|
||||
|
||||
for file in deps:
|
||||
if not (file.startswith("bazel-out") or file.startswith("external")):
|
||||
copy_file(file, srcs_dir)
|
||||
|
||||
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
if sources_path is None:
|
||||
tmpdir = tempfile.TemporaryDirectory(prefix="jax")
|
||||
sources_path = tmpdir.name
|
||||
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_srcs(args.srcs, pathlib.Path(sources_path))
|
||||
build_utils.build_wheel(
|
||||
sources_path,
|
||||
args.output_path,
|
||||
package_name="jax",
|
||||
git_hash=args.jaxlib_git_hash,
|
||||
)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
@ -61,6 +61,7 @@ config_setting(
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
"version.py",
|
||||
"py.typed",
|
||||
])
|
||||
|
||||
exports_files(
|
||||
@ -117,7 +118,6 @@ py_library(
|
||||
# JAX does provide some public test utilities (see jax/test_util.py);
|
||||
# these are available in jax.test_util via the standard :jax target.
|
||||
name = "test_util",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"_src/test_util.py",
|
||||
"_src/test_warning_util.py",
|
||||
@ -134,8 +134,8 @@ py_library(
|
||||
# TODO(necula): break the internal_test_util into smaller build targets.
|
||||
py_library(
|
||||
name = "internal_test_util",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"_src/internal_test_util/__init__.py",
|
||||
"_src/internal_test_util/deprecation_module.py",
|
||||
"_src/internal_test_util/lax_test_util.py",
|
||||
] + glob(
|
||||
@ -151,7 +151,6 @@ py_library(
|
||||
|
||||
py_library(
|
||||
name = "internal_test_harnesses",
|
||||
testonly = 1,
|
||||
srcs = ["_src/internal_test_util/test_harnesses.py"],
|
||||
visibility = [":internal"] + jax_internal_test_harnesses_visibility,
|
||||
deps = [
|
||||
@ -165,7 +164,6 @@ py_library(
|
||||
|
||||
py_library(
|
||||
name = "internal_export_back_compat_test_util",
|
||||
testonly = 1,
|
||||
srcs = ["_src/internal_test_util/export_back_compat_test_util.py"],
|
||||
visibility = [
|
||||
":internal",
|
||||
@ -681,7 +679,6 @@ pytype_strict_library(
|
||||
|
||||
pytype_strict_library(
|
||||
name = "pallas_experimental_gpu_ops",
|
||||
testonly = True,
|
||||
srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"],
|
||||
visibility = [
|
||||
":mosaic_gpu_users",
|
||||
@ -1121,7 +1118,6 @@ pytype_library(
|
||||
|
||||
pytype_library(
|
||||
name = "sparse_test_util",
|
||||
testonly = 1,
|
||||
srcs = [
|
||||
"experimental/sparse/test_util.py",
|
||||
],
|
||||
@ -1183,6 +1179,7 @@ pytype_library(
|
||||
pytype_library(
|
||||
name = "compilation_cache",
|
||||
srcs = [
|
||||
"experimental/compilation_cache/__init__.py",
|
||||
"experimental/compilation_cache/compilation_cache.py",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -26,7 +26,10 @@ package(
|
||||
|
||||
py_library(
|
||||
name = "core",
|
||||
srcs = ["core.py"],
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"core.py",
|
||||
],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax/_src/pallas",
|
||||
|
@ -29,7 +29,10 @@ package(
|
||||
|
||||
pytype_strict_library(
|
||||
name = "core",
|
||||
srcs = ["core.py"],
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"core.py",
|
||||
],
|
||||
deps = ["//jax/_src/pallas"],
|
||||
)
|
||||
|
||||
|
41
jax/experimental/array_serialization/BUILD
Normal file
41
jax/experimental/array_serialization/BUILD
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"py_deps",
|
||||
"pytype_library",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "serialization",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"serialization.py",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//jax",
|
||||
] + py_deps([
|
||||
"numpy",
|
||||
"tensorstore",
|
||||
"absl/logging",
|
||||
]),
|
||||
)
|
@ -23,6 +23,15 @@ package(
|
||||
default_visibility = ["//jax:jax_extend_users"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "mlir",
|
||||
srcs = ["__init__.py"],
|
||||
deps = [
|
||||
":ir",
|
||||
":pass_manager",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "ir",
|
||||
srcs = ["ir.py"],
|
||||
|
@ -23,6 +23,24 @@ package(
|
||||
default_visibility = ["//jax:jax_extend_users"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "dialects",
|
||||
srcs = ["__init__.py"],
|
||||
deps = [
|
||||
":arithmetic_dialect",
|
||||
":builtin_dialect",
|
||||
":chlo_dialect",
|
||||
":func_dialect",
|
||||
":math_dialect",
|
||||
":memref_dialect",
|
||||
":scf_dialect",
|
||||
":sdy_dialect",
|
||||
":sparse_tensor_dialect",
|
||||
":stablehlo_dialect",
|
||||
":vector_dialect",
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "arithmetic_dialect",
|
||||
srcs = ["arith.py"],
|
||||
|
@ -25,9 +25,26 @@ package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "pgo_nsys_converter",
|
||||
srcs = [
|
||||
"pgo_nsys_converter.py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "colab_tpu",
|
||||
srcs = [
|
||||
"colab_tpu.py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "jax_to_ir",
|
||||
srcs = ["jax_to_ir.py"],
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
"jax_to_ir.py",
|
||||
],
|
||||
tags = [
|
||||
"ignore_for_dep=third_party.py.jax.experimental.jax2tf",
|
||||
"ignore_for_dep=third_party.py.tensorflow",
|
||||
|
@ -87,6 +87,7 @@ _py_deps = {
|
||||
"ml_dtypes": ["@pypi_ml_dtypes//:pkg"],
|
||||
"numpy": ["@pypi_numpy//:pkg"],
|
||||
"scipy": ["@pypi_scipy//:pkg"],
|
||||
"tensorstore": ["@pypi_tensorstore//:pkg"],
|
||||
"tensorflow_core": [],
|
||||
"torch": [],
|
||||
"zstandard": get_zstandard(),
|
||||
@ -125,12 +126,15 @@ def pytype_library(name, pytype_srcs = None, **kwargs):
|
||||
_ = pytype_srcs # @unused
|
||||
native.py_library(name = name, **kwargs)
|
||||
|
||||
def pytype_strict_library(name, pytype_srcs = None, **kwargs):
|
||||
_ = pytype_srcs # @unused
|
||||
native.py_library(name = name, **kwargs)
|
||||
def pytype_strict_library(name, pytype_srcs = [], **kwargs):
|
||||
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
|
||||
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
|
||||
native.py_library(name = name, data = data, **new_kwargs)
|
||||
|
||||
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
|
||||
lib_rule(name = name, **kwargs)
|
||||
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
|
||||
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
|
||||
lib_rule(name = name, data = data, **new_kwargs)
|
||||
|
||||
def py_extension(name, srcs, copts, deps, linkopts = []):
|
||||
nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
|
||||
@ -321,8 +325,8 @@ def jax_generate_backend_suites(backends = []):
|
||||
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
|
||||
)
|
||||
|
||||
def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
|
||||
if no_abi:
|
||||
def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version):
|
||||
if no_abi or platform_independent:
|
||||
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
|
||||
else:
|
||||
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
|
||||
@ -332,7 +336,9 @@ def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_ve
|
||||
python_version = python_version,
|
||||
major_python_version = python_version[0],
|
||||
wheel_version = wheel_version,
|
||||
wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
|
||||
wheel_platform_tag = "any" if platform_independent else "_".join(
|
||||
PLATFORM_TAGS_DICT[platform_name, cpu_name],
|
||||
),
|
||||
)
|
||||
|
||||
def _jax_wheel_impl(ctx):
|
||||
@ -360,10 +366,13 @@ def _jax_wheel_impl(ctx):
|
||||
env["JAX_RELEASE"] = "1"
|
||||
|
||||
cpu = ctx.attr.cpu
|
||||
no_abi = ctx.attr.no_abi
|
||||
platform_independent = ctx.attr.platform_independent
|
||||
platform_name = ctx.attr.platform_name
|
||||
wheel_name = _get_full_wheel_name(
|
||||
package_name = ctx.attr.wheel_name,
|
||||
no_abi = ctx.attr.no_abi,
|
||||
no_abi = no_abi,
|
||||
platform_independent = platform_independent,
|
||||
platform_name = platform_name,
|
||||
cpu_name = cpu,
|
||||
wheel_version = full_wheel_version,
|
||||
@ -373,7 +382,8 @@ def _jax_wheel_impl(ctx):
|
||||
wheel_dir = output_file.path[:output_file.path.rfind("/")]
|
||||
|
||||
args.add("--output_path", wheel_dir) # required argument
|
||||
args.add("--cpu", cpu) # required argument
|
||||
if not platform_independent:
|
||||
args.add("--cpu", cpu)
|
||||
args.add("--jaxlib_git_hash", git_hash) # required argument
|
||||
|
||||
if ctx.attr.enable_cuda:
|
||||
@ -389,14 +399,21 @@ def _jax_wheel_impl(ctx):
|
||||
if ctx.attr.skip_gpu_kernels:
|
||||
args.add("--skip_gpu_kernels")
|
||||
|
||||
srcs = []
|
||||
for src in ctx.attr.source_files:
|
||||
for f in src.files.to_list():
|
||||
srcs.append(f)
|
||||
args.add("--srcs=%s" % (f.path))
|
||||
|
||||
args.set_param_file_format("flag_per_line")
|
||||
args.use_param_file("@%s", use_always = False)
|
||||
ctx.actions.run(
|
||||
arguments = [args],
|
||||
inputs = [],
|
||||
inputs = srcs,
|
||||
outputs = [output_file],
|
||||
executable = executable,
|
||||
env = env,
|
||||
mnemonic = "BuildJaxWheel",
|
||||
)
|
||||
|
||||
return [DefaultInfo(files = depset(direct = [output_file]))]
|
||||
@ -411,9 +428,11 @@ _jax_wheel = rule(
|
||||
),
|
||||
"wheel_name": attr.string(mandatory = True),
|
||||
"no_abi": attr.bool(default = False),
|
||||
"platform_independent": attr.bool(default = False),
|
||||
"cpu": attr.string(mandatory = True),
|
||||
"platform_name": attr.string(mandatory = True),
|
||||
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
|
||||
"source_files": attr.label_list(allow_files = True),
|
||||
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
|
||||
"enable_cuda": attr.bool(default = False),
|
||||
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
|
||||
@ -427,7 +446,15 @@ _jax_wheel = rule(
|
||||
executable = False,
|
||||
)
|
||||
|
||||
def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
|
||||
def jax_wheel(
|
||||
name,
|
||||
wheel_binary,
|
||||
wheel_name,
|
||||
no_abi = False,
|
||||
platform_independent = False,
|
||||
enable_cuda = False,
|
||||
platform_version = "",
|
||||
source_files = []):
|
||||
"""Create jax artifact wheels.
|
||||
|
||||
Common artifact attributes are grouped within a single macro.
|
||||
@ -437,8 +464,10 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
|
||||
wheel_binary: the binary to use to build the wheel
|
||||
wheel_name: the name of the wheel
|
||||
no_abi: whether to build a wheel without ABI
|
||||
platform_independent: whether to build a wheel without platform tag
|
||||
enable_cuda: whether to build a cuda wheel
|
||||
platform_version: the cuda version to use for the wheel
|
||||
source_files: the source files to include in the wheel
|
||||
|
||||
Returns:
|
||||
A directory containing the wheel
|
||||
@ -448,6 +477,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
|
||||
wheel_binary = wheel_binary,
|
||||
wheel_name = wheel_name,
|
||||
no_abi = no_abi,
|
||||
platform_independent = platform_independent,
|
||||
enable_cuda = enable_cuda,
|
||||
platform_version = platform_version,
|
||||
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
|
||||
@ -465,6 +495,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
|
||||
"//jaxlib/tools:arm64": "aarch64",
|
||||
"@platforms//cpu:x86_64": "x86_64",
|
||||
}),
|
||||
source_files = source_files,
|
||||
)
|
||||
|
||||
jax_test_file_visibility = []
|
||||
|
12
tests/BUILD
12
tests/BUILD
@ -1650,6 +1650,18 @@ jax_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "serialization_test",
|
||||
srcs = ["serialization_test.py"],
|
||||
enable_configs = [
|
||||
"tpu_v3_2x2",
|
||||
],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
"//jax/experimental/array_serialization:serialization",
|
||||
],
|
||||
)
|
||||
|
||||
exports_files(
|
||||
[
|
||||
"api_test.py",
|
||||
|
Loading…
x
Reference in New Issue
Block a user