Create jax wheel build target.

This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
This commit is contained in:
jax authors 2025-02-25 09:28:35 -08:00
parent 7a162f2abc
commit eb912ad0d9
18 changed files with 453 additions and 20 deletions

86
BUILD Normal file
View File

@ -0,0 +1,86 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@tsl//third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
load(
"//jaxlib:jax.bzl",
"jax_wheel",
)
collect_data_files(
name = "transitive_py_data",
deps = ["//jax"],
)
transitive_py_deps(
name = "transitive_py_deps",
deps = [
"//jax",
"//jax:compilation_cache",
"//jax:experimental",
"//jax:experimental_colocated_python",
"//jax:experimental_sparse",
"//jax:internal_export_back_compat_test_util",
"//jax:internal_test_harnesses",
"//jax:internal_test_util",
"//jax:lax_reference",
"//jax:pallas_experimental_gpu_ops",
"//jax:pallas_gpu_ops",
"//jax:pallas_mosaic_gpu",
"//jax:pallas_tpu_ops",
"//jax:pallas_triton",
"//jax:source_mapper",
"//jax:sparse_test_util",
"//jax:test_util",
"//jax/_src/lib",
"//jax/_src/pallas/mosaic_gpu",
"//jax/experimental/array_serialization:serialization",
"//jax/experimental/jax2tf",
"//jax/extend",
"//jax/extend:ifrt_programs",
"//jax/extend/mlir",
"//jax/extend/mlir/dialects",
"//jax/tools:colab_tpu",
"//jax/tools:jax_to_ir",
"//jax/tools:pgo_nsys_converter",
],
)
py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
deps = [
"//jaxlib/tools:build_utils",
"@pypi_build//:pkg",
"@pypi_setuptools//:pkg",
"@pypi_wheel//:pkg",
],
)
jax_wheel(
name = "jax_wheel",
platform_independent = True,
source_files = [
":transitive_py_data",
":transitive_py_deps",
"//jax:py.typed",
"AUTHORS",
"LICENSE",
"README.md",
"pyproject.toml",
"setup.py",
],
wheel_binary = ":build_wheel",
wheel_name = "jax",
)

View File

@ -566,6 +566,29 @@ sortedcontainers==2.4.0 \
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
# via hypothesis
tensorstore==0.1.72 \
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
# via -r build/test-requirements.txt
tomli==2.0.1 \
--hash=sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc \
--hash=sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f

View File

@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
# via hypothesis
tensorstore==0.1.72 \
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
# via -r build/test-requirements.txt
typing-extensions==4.12.0rc1 \
--hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe

View File

@ -561,6 +561,29 @@ sortedcontainers==2.4.0 \
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
# via hypothesis
tensorstore==0.1.72 \
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
# via -r build/test-requirements.txt
typing-extensions==4.12.0rc1 \
--hash=sha256:be199d06d8f09ca2c9425e3aa04a9afba33e892fe079dea959e72df7f8442343 \
--hash=sha256:f933a7b288a919ca97adbff656e52ff81f7ff25d98a2aabb9355ca4090f772fe

View File

@ -634,6 +634,29 @@ sortedcontainers==2.4.0 \
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
# via hypothesis
tensorstore==0.1.72 \
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
# via -r build/test-requirements.txt
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8

View File

@ -588,6 +588,29 @@ sortedcontainers==2.4.0 \
--hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \
--hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0
# via hypothesis
tensorstore==0.1.72 \
--hash=sha256:08c5318535aac5e20e247c6e9b43f5887b2293f548de7279650bc73804ccf3ed \
--hash=sha256:0cd951e593a17babbbde1410cfadb4a04e1cddfa5ace0de5ccb41029223f96b9 \
--hash=sha256:170172b698fefb4b5507c6cb339ca0b75d56d12ba6a43d9569c61800c1eeb121 \
--hash=sha256:2fdfa0118be0721c110bcbe7e464758f78d3e14ee8c30a911eb8f4465e6c2e81 \
--hash=sha256:4a6825cdb6751663ca0bd9abd528ea354ad2199f549bf1f36feac79a6c06efe2 \
--hash=sha256:599cc7b26b0c96373e89ff5bcf9b76e832802169229680bef985b10011f9bae7 \
--hash=sha256:5d410c879dc4b34036ec38e20ff05c7e3b0ad5d1eb595412b27a9dbb5e435035 \
--hash=sha256:5ed6fe937b0433b573c3d6805d0759d33ccc24aa2aba720e4b8ba689c2f9775f \
--hash=sha256:66c0658689243af0825fff222fb56fdf05a8553bcb3b471dbf18830161302986 \
--hash=sha256:721d599db0113d75ab6ba1365989bbaf2ab752d7a6268f975c8bfd3a8eb6084b \
--hash=sha256:763d7f6898711783f199c8226a9c0b259546f5c6d9b4dc0ad3c9e39627060022 \
--hash=sha256:7c9413f8318a4fa259ec5325f569c0759bccee936df44bd2f7bb35c8afdcdfc8 \
--hash=sha256:9113d3fcf78c1366688aa90ee7efdc86b57962ea72276944cc57e916a6180749 \
--hash=sha256:92fac5e2cbc90e5ca8fc72c5bf112816d981e266a3cf9fb1681ba8b3f59537ef \
--hash=sha256:9c3a36f681ffcc104ba931d471447e8901e64e8cc6913b61792870ff59529961 \
--hash=sha256:a41b4fe0603943d23472619a8ada70b8d2c9458747fad88b0ce7b29f1ccf4e74 \
--hash=sha256:a7e7b02da26ca5c95b3c613efd0fe10c082dfa4dc3e9818fefc69e30fe70ea1e \
--hash=sha256:b71134b85f540e17a1ae65da1fb906781b7470ef0ed71d98d29459325897f574 \
--hash=sha256:c0f722218f494b1631dbec451b9863f579054e27da2f39aab418db4493694abe \
--hash=sha256:d5dced3f367308e9fa8e7b72e9e57a4c491fa47c066e035ac33421e2b2408e3f \
--hash=sha256:ed916b9aeca242a3f367679f65ba376149251ebb28b873becd76c73b688399b6
# via -r build/test-requirements.txt
typing-extensions==4.12.2 \
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8

View File

@ -20,3 +20,4 @@ matplotlib~=3.8.4; python_version=="3.10"
matplotlib; python_version>="3.11"
opt-einsum
auditwheel
tensorstore==0.1.72

100
build_wheel.py Normal file
View File

@ -0,0 +1,100 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script that builds a JAX wheel, intended to be run via bazel run as part
# of the JAX build process.
import argparse
import os
import pathlib
import shutil
import tempfile
from jaxlib.tools import build_utils
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument(
"--sources_path",
default=None,
help=(
"Path in which the wheel's sources should be prepared. Optional. If "
"omitted, a temporary directory will be used."
),
)
parser.add_argument(
"--output_path",
default=None,
required=True,
help="Path to which the output wheel should be written. Required.",
)
parser.add_argument(
"--jaxlib_git_hash",
default="",
required=True,
help="Git hash. Empty if unknown. Optional.",
)
parser.add_argument(
"--srcs", help="source files for the wheel", action="append"
)
args = parser.parse_args()
def copy_file(
src_file: str,
dst_dir: str,
) -> None:
"""Copy a file to the destination directory.
Args:
src_file: file to be copied
dst_dir: destination directory
"""
dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
os.makedirs(dest_dir_path, exist_ok=True)
shutil.copy(src_file, dest_dir_path)
os.chmod(os.path.join(dst_dir, src_file), 0o644)
def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
"""Filter the sources and copy them to the destination directory.
Args:
deps: a list of paths to files.
srcs_dir: target directory where files are copied to.
"""
for file in deps:
if not (file.startswith("bazel-out") or file.startswith("external")):
copy_file(file, srcs_dir)
tmpdir = None
sources_path = args.sources_path
if sources_path is None:
tmpdir = tempfile.TemporaryDirectory(prefix="jax")
sources_path = tmpdir.name
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_srcs(args.srcs, pathlib.Path(sources_path))
build_utils.build_wheel(
sources_path,
args.output_path,
package_name="jax",
git_hash=args.jaxlib_git_hash,
)
finally:
if tmpdir:
tmpdir.cleanup()

View File

@ -61,6 +61,7 @@ config_setting(
exports_files([
"LICENSE",
"version.py",
"py.typed",
])
exports_files(
@ -117,7 +118,6 @@ py_library(
# JAX does provide some public test utilities (see jax/test_util.py);
# these are available in jax.test_util via the standard :jax target.
name = "test_util",
testonly = 1,
srcs = [
"_src/test_util.py",
"_src/test_warning_util.py",
@ -134,8 +134,8 @@ py_library(
# TODO(necula): break the internal_test_util into smaller build targets.
py_library(
name = "internal_test_util",
testonly = 1,
srcs = [
"_src/internal_test_util/__init__.py",
"_src/internal_test_util/deprecation_module.py",
"_src/internal_test_util/lax_test_util.py",
] + glob(
@ -151,7 +151,6 @@ py_library(
py_library(
name = "internal_test_harnesses",
testonly = 1,
srcs = ["_src/internal_test_util/test_harnesses.py"],
visibility = [":internal"] + jax_internal_test_harnesses_visibility,
deps = [
@ -165,7 +164,6 @@ py_library(
py_library(
name = "internal_export_back_compat_test_util",
testonly = 1,
srcs = ["_src/internal_test_util/export_back_compat_test_util.py"],
visibility = [
":internal",
@ -681,7 +679,6 @@ pytype_strict_library(
pytype_strict_library(
name = "pallas_experimental_gpu_ops",
testonly = True,
srcs = ["//jax/experimental/pallas/ops/gpu:mgpu_ops"],
visibility = [
":mosaic_gpu_users",
@ -1121,7 +1118,6 @@ pytype_library(
pytype_library(
name = "sparse_test_util",
testonly = 1,
srcs = [
"experimental/sparse/test_util.py",
],
@ -1183,6 +1179,7 @@ pytype_library(
pytype_library(
name = "compilation_cache",
srcs = [
"experimental/compilation_cache/__init__.py",
"experimental/compilation_cache/compilation_cache.py",
],
visibility = ["//visibility:public"],

View File

@ -26,7 +26,10 @@ package(
py_library(
name = "core",
srcs = ["core.py"],
srcs = [
"__init__.py",
"core.py",
],
deps = [
"//jax",
"//jax/_src/pallas",

View File

@ -29,7 +29,10 @@ package(
pytype_strict_library(
name = "core",
srcs = ["core.py"],
srcs = [
"__init__.py",
"core.py",
],
deps = ["//jax/_src/pallas"],
)

View File

@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"py_deps",
"pytype_library",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
)
pytype_library(
name = "serialization",
srcs = [
"__init__.py",
"serialization.py",
],
visibility = ["//visibility:public"],
deps = [
"//jax",
] + py_deps([
"numpy",
"tensorstore",
"absl/logging",
]),
)

View File

@ -23,6 +23,15 @@ package(
default_visibility = ["//jax:jax_extend_users"],
)
pytype_strict_library(
name = "mlir",
srcs = ["__init__.py"],
deps = [
":ir",
":pass_manager",
],
)
pytype_strict_library(
name = "ir",
srcs = ["ir.py"],

View File

@ -23,6 +23,24 @@ package(
default_visibility = ["//jax:jax_extend_users"],
)
pytype_strict_library(
name = "dialects",
srcs = ["__init__.py"],
deps = [
":arithmetic_dialect",
":builtin_dialect",
":chlo_dialect",
":func_dialect",
":math_dialect",
":memref_dialect",
":scf_dialect",
":sdy_dialect",
":sparse_tensor_dialect",
":stablehlo_dialect",
":vector_dialect",
],
)
pytype_strict_library(
name = "arithmetic_dialect",
srcs = ["arith.py"],

View File

@ -25,9 +25,26 @@ package(
default_visibility = ["//visibility:public"],
)
py_library(
name = "pgo_nsys_converter",
srcs = [
"pgo_nsys_converter.py",
],
)
py_library(
name = "colab_tpu",
srcs = [
"colab_tpu.py",
],
)
py_library(
name = "jax_to_ir",
srcs = ["jax_to_ir.py"],
srcs = [
"__init__.py",
"jax_to_ir.py",
],
tags = [
"ignore_for_dep=third_party.py.jax.experimental.jax2tf",
"ignore_for_dep=third_party.py.tensorflow",

View File

@ -87,6 +87,7 @@ _py_deps = {
"ml_dtypes": ["@pypi_ml_dtypes//:pkg"],
"numpy": ["@pypi_numpy//:pkg"],
"scipy": ["@pypi_scipy//:pkg"],
"tensorstore": ["@pypi_tensorstore//:pkg"],
"tensorflow_core": [],
"torch": [],
"zstandard": get_zstandard(),
@ -125,12 +126,15 @@ def pytype_library(name, pytype_srcs = None, **kwargs):
_ = pytype_srcs # @unused
native.py_library(name = name, **kwargs)
def pytype_strict_library(name, pytype_srcs = None, **kwargs):
_ = pytype_srcs # @unused
native.py_library(name = name, **kwargs)
def pytype_strict_library(name, pytype_srcs = [], **kwargs):
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
native.py_library(name = name, data = data, **new_kwargs)
def py_library_providing_imports_info(*, name, lib_rule = native.py_library, pytype_srcs = [], **kwargs):
lib_rule(name = name, **kwargs)
data = pytype_srcs + (kwargs["data"] if "data" in kwargs else [])
new_kwargs = {k: v for k, v in kwargs.items() if k != "data"}
lib_rule(name = name, data = data, **new_kwargs)
def py_extension(name, srcs, copts, deps, linkopts = []):
nanobind_extension(name, srcs = srcs, copts = copts, linkopts = linkopts, deps = deps, module_name = name)
@ -321,8 +325,8 @@ def jax_generate_backend_suites(backends = []):
tags = ["-jax_test_%s" % backend for backend in backends] + ["-manual"],
)
def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_version):
if no_abi:
def _get_full_wheel_name(package_name, no_abi, platform_independent, platform_name, cpu_name, wheel_version):
if no_abi or platform_independent:
wheel_name_template = "{package_name}-{wheel_version}-py{major_python_version}-none-{wheel_platform_tag}.whl"
else:
wheel_name_template = "{package_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl"
@ -332,7 +336,9 @@ def _get_full_wheel_name(package_name, no_abi, platform_name, cpu_name, wheel_ve
python_version = python_version,
major_python_version = python_version[0],
wheel_version = wheel_version,
wheel_platform_tag = "_".join(PLATFORM_TAGS_DICT[platform_name, cpu_name]),
wheel_platform_tag = "any" if platform_independent else "_".join(
PLATFORM_TAGS_DICT[platform_name, cpu_name],
),
)
def _jax_wheel_impl(ctx):
@ -360,10 +366,13 @@ def _jax_wheel_impl(ctx):
env["JAX_RELEASE"] = "1"
cpu = ctx.attr.cpu
no_abi = ctx.attr.no_abi
platform_independent = ctx.attr.platform_independent
platform_name = ctx.attr.platform_name
wheel_name = _get_full_wheel_name(
package_name = ctx.attr.wheel_name,
no_abi = ctx.attr.no_abi,
no_abi = no_abi,
platform_independent = platform_independent,
platform_name = platform_name,
cpu_name = cpu,
wheel_version = full_wheel_version,
@ -373,7 +382,8 @@ def _jax_wheel_impl(ctx):
wheel_dir = output_file.path[:output_file.path.rfind("/")]
args.add("--output_path", wheel_dir) # required argument
args.add("--cpu", cpu) # required argument
if not platform_independent:
args.add("--cpu", cpu)
args.add("--jaxlib_git_hash", git_hash) # required argument
if ctx.attr.enable_cuda:
@ -389,14 +399,21 @@ def _jax_wheel_impl(ctx):
if ctx.attr.skip_gpu_kernels:
args.add("--skip_gpu_kernels")
srcs = []
for src in ctx.attr.source_files:
for f in src.files.to_list():
srcs.append(f)
args.add("--srcs=%s" % (f.path))
args.set_param_file_format("flag_per_line")
args.use_param_file("@%s", use_always = False)
ctx.actions.run(
arguments = [args],
inputs = [],
inputs = srcs,
outputs = [output_file],
executable = executable,
env = env,
mnemonic = "BuildJaxWheel",
)
return [DefaultInfo(files = depset(direct = [output_file]))]
@ -411,9 +428,11 @@ _jax_wheel = rule(
),
"wheel_name": attr.string(mandatory = True),
"no_abi": attr.bool(default = False),
"platform_independent": attr.bool(default = False),
"cpu": attr.string(mandatory = True),
"platform_name": attr.string(mandatory = True),
"git_hash": attr.label(default = Label("//jaxlib/tools:jaxlib_git_hash")),
"source_files": attr.label_list(allow_files = True),
"output_path": attr.label(default = Label("//jaxlib/tools:output_path")),
"enable_cuda": attr.bool(default = False),
# A cuda/rocm version is required for gpu wheels; for cpu wheels, it can be an empty string.
@ -427,7 +446,15 @@ _jax_wheel = rule(
executable = False,
)
def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = False, platform_version = ""):
def jax_wheel(
name,
wheel_binary,
wheel_name,
no_abi = False,
platform_independent = False,
enable_cuda = False,
platform_version = "",
source_files = []):
"""Create jax artifact wheels.
Common artifact attributes are grouped within a single macro.
@ -437,8 +464,10 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
wheel_binary: the binary to use to build the wheel
wheel_name: the name of the wheel
no_abi: whether to build a wheel without ABI
platform_independent: whether to build a wheel without platform tag
enable_cuda: whether to build a cuda wheel
platform_version: the cuda version to use for the wheel
source_files: the source files to include in the wheel
Returns:
A directory containing the wheel
@ -448,6 +477,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
wheel_binary = wheel_binary,
wheel_name = wheel_name,
no_abi = no_abi,
platform_independent = platform_independent,
enable_cuda = enable_cuda,
platform_version = platform_version,
# git_hash is empty by default. Use `--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD)`
@ -465,6 +495,7 @@ def jax_wheel(name, wheel_binary, wheel_name, no_abi = False, enable_cuda = Fals
"//jaxlib/tools:arm64": "aarch64",
"@platforms//cpu:x86_64": "x86_64",
}),
source_files = source_files,
)
jax_test_file_visibility = []

View File

@ -1650,6 +1650,18 @@ jax_py_test(
],
)
jax_multiplatform_test(
name = "serialization_test",
srcs = ["serialization_test.py"],
enable_configs = [
"tpu_v3_2x2",
],
deps = [
"//jax:experimental",
"//jax/experimental/array_serialization:serialization",
],
)
exports_files(
[
"api_test.py",