rocm_jax/jax/version.py

162 lines
6.6 KiB
Python
Raw Permalink Normal View History

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is included as part of both jax and jaxlib. It is also
# eval()-ed by setup.py, so it should not have any dependencies.
from __future__ import annotations
import datetime
import os
import pathlib
import subprocess
# Base version number of jax/jaxlib, without any dev/rc/local suffix. The
# version helpers in this file append suffixes to it.
# NOTE(review): Bazel's jax_python_wheel_version_repository rule is understood
# to read this value textually; keep it a plain string literal on one line.
_version = "0.6.0"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
# _write_version() locates it by exact text match and requires that text to
# occur exactly twice in this file (here and inside _write_version), so never
# copy the declaration into a comment. When it holds a string,
# _get_version_string() returns that string verbatim.
_release_version: str | None = None
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
# _write_version() fills it in from the JAX_GIT_HASH environment variable,
# using the same exact-text, exactly-twice matching as the line above.
_git_hash: str | None = None
def _get_version_string() -> str:
  """Computes the value exported as ``__version__`` when this module loads.

  Sources are tried in order:
    1. ``_release_version``, if a build or source distribution stamped it.
    2. ``_version`` followed by the WHEEL_VERSION_SUFFIX environment
       variable, if that variable is non-empty.
    3. A dev version derived from the enclosing git checkout or, failing
       that, one built from today's date.
  """
  if _release_version is not None:
    # Stamped by the build/source distribution: report it unchanged.
    return _release_version

  wheel_suffix = os.getenv("WHEEL_VERSION_SUFFIX")
  if wheel_suffix:
    return _version + wheel_suffix

  git_version = _version_from_git_tree(_version)
  if git_version:
    return git_version
  return _version_from_todays_date(_version)
def _version_from_todays_date(base_version: str) -> str:
  """Returns ``base_version`` tagged as a dev release for today's local date.

  For example, ``"0.6.0"`` becomes ``"0.6.0.dev20250101"`` on 2025-01-01.
  """
  today = datetime.date.today()
  return f"{base_version}.dev{today:%Y%m%d}"
def _version_from_git_tree(base_version: str) -> str | None:
try:
root_directory = os.path.dirname(os.path.realpath(__file__))
# Get date string from date of most recent git commit, and the abbreviated
# hash of that commit.
p = subprocess.Popen(["git", "show", "-s", "--format=%at-%h", "HEAD"],
cwd=root_directory,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, _ = p.communicate()
timestamp, commit_hash = stdout.decode().strip().split('-', 1)
datestring = datetime.date.fromtimestamp(int(timestamp)).strftime("%Y%m%d")
assert datestring.isnumeric()
assert commit_hash.isalnum()
except:
return None
else:
version = f"{base_version}.dev{datestring}+{commit_hash}"
suffix = os.environ.get("JAX_CUSTOM_VERSION_SUFFIX", None)
if suffix:
return version + "." + suffix
return version
def _get_version_for_build() -> str:
  """Determine the version at build time.

  The returned version string depends on which of the following apply,
  checked in this order:
  - if this file has already been stamped with a release version
    (``_release_version`` is not None): that value is returned unchanged.
  - if WHEEL_VERSION_SUFFIX is set: version looks like "0.5.1.dev20230906+ge58560fdc"
    Here the WHEEL_VERSION_SUFFIX value is ".dev20230906+ge58560fdc".
    Please note that the WHEEL_VERSION_SUFFIX value is not the same as the
    JAX_CUSTOM_VERSION_SUFFIX value, and WHEEL_VERSION_SUFFIX is set by Bazel
    wheel build rule.
  - if JAX_RELEASE or JAXLIB_RELEASE are set: version looks like "0.4.16"
  - if JAX_NIGHTLY or JAXLIB_NIGHTLY are set: version looks like "0.4.16.dev20230906"
  - if none are set: version looks like "0.4.16.dev20230906+e58560fdc", built
    from the date and abbreviated hash of the HEAD git commit (with "." and
    the JAX_CUSTOM_VERSION_SUFFIX value appended when that is set). Without a
    usable git checkout, it falls back to today's date, e.g.
    "0.4.16.dev20230906".
  """
  if _release_version is not None:
    return _release_version
  wheel_suffix = os.getenv("WHEEL_VERSION_SUFFIX")
  if wheel_suffix:
    return _version + wheel_suffix
  if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
    return _version
  if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
    return _version_from_todays_date(_version)
  return _version_from_git_tree(_version) or _version_from_todays_date(_version)
def _is_prerelease() -> bool:
  """Determine if this is a pre-release ("rc" wheels) build.

  Returns True when the WHEEL_VERSION_SUFFIX environment variable (set by the
  Bazel wheel build rule) starts with "rc", e.g. "rc1". Returns False when it
  is unset, empty, or any other suffix such as ".dev20230906+...".
  """
  # str.startswith already returns a bool; no conditional expression needed.
  return os.getenv("WHEEL_VERSION_SUFFIX", "").startswith("rc")
def _write_version(fname: str) -> None:
  """Used by setup.py to write the specified version info into the source tree.

  Rewrites the copy of this file at ``fname`` in two ways:
  - the release-version declaration is made to hold the value returned by
    ``_get_version_for_build()``;
  - if the JAX_GIT_HASH environment variable is non-empty, the git-hash
    declaration is made to hold that hash.

  Raises:
    RuntimeError: if a declaration to be stamped does not occur exactly twice
      in ``fname``.
  """

  def _stamp(contents: str, old: str, new: str) -> str:
    # Replace both occurrences of `old`: one at the declaration near the top
    # of the file, and one as the literal in this function. Any other count
    # means the file was edited unexpectedly, so fail rather than stamp blind.
    count = contents.count(old)
    if count != 2:
      raise RuntimeError(
          f"Build: expected exactly 2 occurrences of {old!r} in {fname}, "
          f"found {count}")
    return contents.replace(old, new)

  release_version = _get_version_for_build()
  old_version_string = "_release_version: str | None = None"
  new_version_string = f"_release_version: str = {release_version!r}"
  fhandle = pathlib.Path(fname)
  # Python source is UTF-8 by default (PEP 3120); do not depend on the
  # platform's locale encoding when reading or writing it.
  contents = fhandle.read_text(encoding="utf-8")
  contents = _stamp(contents, old_version_string, new_version_string)
  githash = os.environ.get("JAX_GIT_HASH")
  if githash:
    old_githash_string = "_git_hash: str | None = None"
    new_githash_string = f"_git_hash: str = {githash!r}"
    contents = _stamp(contents, old_githash_string, new_githash_string)
  fhandle.write_text(contents, encoding="utf-8")
def _get_cmdclass(pkg_source_path):
  """Returns setuptools command overrides that stamp the version into builds.

  Args:
    pkg_source_path: path, relative to the build/sdist output root, of the
      directory that holds this file's copy (joined with this file's
      basename).

  Returns:
    A ``cmdclass`` mapping with ``"sdist"`` and ``"build_py"`` entries. While
    this file is unstamped (``_release_version`` is None), both commands
    rewrite their output copy of it via ``_write_version``.
  """
  # setuptools is imported lazily so the module itself stays free of
  # dependencies when it is imported or exec'd outside of a build.
  from setuptools.command.build_py import build_py as build_py_orig  # pytype: disable=import-error
  from setuptools.command.sdist import sdist as sdist_orig  # pytype: disable=import-error

  class _build_py(build_py_orig):

    def run(self):
      needs_stamp = _release_version is None
      if needs_stamp:
        target = os.path.join(self.build_lib, pkg_source_path,
                              os.path.basename(__file__))
        # super().run() only copies files from source -> build if they are
        # missing or outdated. Because _write_version(...) modifies the copy of
        # this file in the build tree, re-building from the same JAX directory
        # would not automatically re-copy a clean version, and _write_version
        # would fail without this deletion. See jax-ml/jax#18252.
        if os.path.isfile(target):
          os.unlink(target)
      super().run()
      if needs_stamp:
        _write_version(target)

  class _sdist(sdist_orig):

    def make_release_tree(self, base_dir, files):
      super().make_release_tree(base_dir, files)
      if _release_version is None:
        target = os.path.join(base_dir, pkg_source_path,
                              os.path.basename(__file__))
        _write_version(target)

  return {"sdist": _sdist, "build_py": _build_py}
# Public version string of this package, resolved once at import time by
# _get_version_string().
__version__ = _get_version_string()
2025-03-18 21:38:14 -04:00
# Minimum jaxlib version this jax declares support for; its tuple form is
# computed into _minimum_jaxlib_version_info at the end of this file.
# NOTE(review): the check that enforces it is not in this file — confirm
# where it lives before changing the value's format.
_minimum_jaxlib_version = "0.5.3"
def _version_as_tuple(version_str: str) -> tuple[int, ...]:
  """Returns the leading numeric release components of ``version_str``.

  Only the dotted run of integers at the very start of the string (the PEP
  440 release segment) is used. Pre-release, dev and local-version suffixes
  are therefore ignored, and a numeric segment that appears after a
  non-numeric one is not mistaken for part of the release:

    "0.6.0"                     -> (0, 6, 0)
    "0.5.1rc1"                  -> (0, 5, 1)
    "0.6.0.dev20250101+abc1234" -> (0, 6, 0)
    "1.2.3+local.5"             -> (1, 2, 3)

  A string that does not start with an ASCII digit yields ``()``.
  """
  # Local import keeps the module's top-level import list minimal; this file
  # is also eval()-ed by setup.py.
  import re
  # [0-9] rather than \d / str.isdigit(): those also accept non-ASCII digit
  # characters, some of which int() rejects.
  match = re.match(r"[0-9]+(?:\.[0-9]+)*", version_str)
  if match is None:
    return ()
  return tuple(int(part) for part in match.group(0).split("."))
# Integer-tuple forms of the version strings above, as produced by
# _version_as_tuple (e.g. for ordered comparisons).
__version_info__ = _version_as_tuple(__version__)
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)