rocm_jax/jax/version.py
Nitin Srinivasan 64e4bf2632 Relax jax dependency constraints to be able to install RC wheels
Also, add a job to the release test workflow that verifies that the release wheels can be installed.

TESTED:
1. Full release: https://github.com/jax-ml/jax/actions/runs/14315832784

2. jax only release: https://github.com/jax-ml/jax/actions/runs/14316157252

PiperOrigin-RevId: 744857804
2025-04-07 14:49:55 -07:00

162 lines
6.6 KiB
Python

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is included as part of both jax and jaxlib. It is also
# eval()-ed by setup.py, so it should not have any dependencies outside
# the Python standard library.
from __future__ import annotations
import datetime
import os
import pathlib
import re
import subprocess
# Base version number of this source tree. The helpers below return it as-is
# for release builds and append dev / nightly / wheel suffixes to it otherwise.
_version = "0.5.4"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
# (_write_version() searches for this exact assignment text and swaps in the
# computed build version; None means this is not a built distribution.)
_release_version: str | None = None
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
# (_write_version() rewrites it only when the JAX_GIT_HASH environment
# variable is non-empty.)
_git_hash: str | None = None
def _get_version_string() -> str:
  """Returns the version string this module reports as ``__version__``.

  The first applicable rule wins:
    1. ``_release_version``, when a build/source distribution has baked it
       into this file.
    2. ``_version`` followed by the WHEEL_VERSION_SUFFIX environment
       variable, when that variable is non-empty.
    3. A dev version derived from the git checkout, or, if that is not
       available, one derived from today's date.
  """
  # Built distributions rewrite _release_version; report it verbatim.
  if _release_version is not None:
    return _release_version

  wheel_suffix = os.getenv("WHEEL_VERSION_SUFFIX")
  if wheel_suffix:
    return _version + wheel_suffix

  git_version = _version_from_git_tree(_version)
  if git_version:
    return git_version
  return _version_from_todays_date(_version)
def _version_from_todays_date(base_version: str) -> str:
  """Returns ``base_version`` with a ``.devYYYYMMDD`` suffix for today.

  The date is today's date according to the local clock.
  """
  today = datetime.date.today()
  return f"{base_version}.dev{today.strftime('%Y%m%d')}"
def _version_from_git_tree(base_version: str) -> str | None:
try:
root_directory = os.path.dirname(os.path.realpath(__file__))
# Get date string from date of most recent git commit, and the abbreviated
# hash of that commit.
p = subprocess.Popen(["git", "show", "-s", "--format=%at-%h", "HEAD"],
cwd=root_directory,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, _ = p.communicate()
timestamp, commit_hash = stdout.decode().strip().split('-', 1)
datestring = datetime.date.fromtimestamp(int(timestamp)).strftime("%Y%m%d")
assert datestring.isnumeric()
assert commit_hash.isalnum()
except:
return None
else:
version = f"{base_version}.dev{datestring}+{commit_hash}"
suffix = os.environ.get("JAX_CUSTOM_VERSION_SUFFIX", None)
if suffix:
return version + "." + suffix
return version
def _get_version_for_build() -> str:
  """Determine the version at build time.

  The first matching rule wins ("set" means set to a non-empty value):

  - if ``_release_version`` has already been baked into this file (a build or
    source distribution of jax/jaxlib): that value is returned unchanged.
  - if WHEEL_VERSION_SUFFIX is set: the version is ``_version`` followed by the
    suffix. For example, the suffix ".dev20230906+ge58560fdc" gives
    "0.5.1.dev20230906+ge58560fdc", and the suffix "rc1" gives "0.5.4rc1".
    WHEEL_VERSION_SUFFIX is set by the Bazel wheel build rule and is not the
    same as JAX_CUSTOM_VERSION_SUFFIX.
  - if JAX_RELEASE or JAXLIB_RELEASE is set: version looks like "0.4.16".
  - if JAX_NIGHTLY or JAXLIB_NIGHTLY is set: version looks like
    "0.4.16.dev20230906".
  - otherwise: version looks like "0.4.16.dev20230906+e58560fdc". The date
    and abbreviated hash come from the git HEAD commit. If
    JAX_CUSTOM_VERSION_SUFFIX is set, ".<suffix>" is appended. When git
    information is unavailable, this falls back to the nightly form using
    today's date.
  """
  if _release_version is not None:
    return _release_version
  # Read the suffix once so the truthiness test and the returned value agree.
  wheel_suffix = os.getenv("WHEEL_VERSION_SUFFIX")
  if wheel_suffix:
    return _version + wheel_suffix
  if os.getenv("JAX_RELEASE") or os.getenv("JAXLIB_RELEASE"):
    return _version
  if os.getenv("JAX_NIGHTLY") or os.getenv("JAXLIB_NIGHTLY"):
    return _version_from_todays_date(_version)
  return _version_from_git_tree(_version) or _version_from_todays_date(_version)
def _is_prerelease() -> bool:
  """Determine if this is a pre-release ("rc" wheels) build.

  Returns True when the WHEEL_VERSION_SUFFIX environment variable starts with
  "rc" (e.g. "rc1"). Returns False when it is unset, empty, or has any other
  value.
  """
  # str.startswith already returns a bool; no conditional expression needed.
  return os.getenv("WHEEL_VERSION_SUFFIX", "").startswith("rc")
def _write_version(fname: str) -> None:
  """Used by setup.py to write the specified version info into the source tree.

  Rewrites the copy of this module at ``fname``:
  - ``_release_version`` is set to the result of ``_get_version_for_build()``.
  - ``_git_hash`` is set to the JAX_GIT_HASH environment variable, but only
    when that variable is non-empty.

  Args:
    fname: path of the copy of this file to rewrite in place.

  Raises:
    RuntimeError: if a placeholder assignment does not occur exactly twice in
      ``fname``. One occurrence is the module-level assignment; the other is
      the search-string literal inside this function.
  """
  def _substitute(text: str, old: str, new: str) -> str:
    # Replaces both occurrences of `old` after checking there are exactly two.
    found = text.count(old)
    if found != 2:
      raise RuntimeError(
          f"Build: could not find {old!r} in {fname}"
          f" (expected 2 occurrences, found {found})")
    return text.replace(old, new)

  release_version = _get_version_for_build()
  old_version_string = "_release_version: str | None = None"
  new_version_string = f"_release_version: str = {release_version!r}"
  fhandle = pathlib.Path(fname)
  # Python source files are UTF-8 by default, so read and write explicitly
  # as UTF-8 rather than with the locale's preferred encoding.
  # Expect two occurrences: one above, and one here.
  contents = _substitute(fhandle.read_text(encoding="utf-8"),
                         old_version_string, new_version_string)
  githash = os.environ.get("JAX_GIT_HASH")
  if githash:
    old_githash_string = "_git_hash: str | None = None"
    new_githash_string = f"_git_hash: str = {githash!r}"
    contents = _substitute(contents, old_githash_string, new_githash_string)
  fhandle.write_text(contents, encoding="utf-8")
def _get_cmdclass(pkg_source_path):
  """Returns setuptools command overrides that stamp the version into builds.

  Args:
    pkg_source_path: directory joined between the build/sdist root and this
      file's basename to locate the copied module (presumably the package's
      path relative to that root; confirm against setup.py).

  Returns:
    A dict mapping "sdist" and "build_py" to subclasses of the setuptools
    commands. Each subclass calls _write_version() on the copy of this file
    it produces, unless _release_version is already set.
  """
  from setuptools.command.build_py import build_py as build_py_orig  # pytype: disable=import-error
  from setuptools.command.sdist import sdist as sdist_orig  # pytype: disable=import-error

  module_filename = os.path.basename(__file__)

  class _build_py(build_py_orig):
    def run(self):
      if _release_version is not None:
        # Version already baked in: a plain copy is all that is needed.
        super().run()
        return
      stamped_copy = os.path.join(self.build_lib, pkg_source_path,
                                  module_filename)
      # super().run() only copies source files that are missing or outdated
      # in the build tree. _write_version() edits the build-tree copy, so a
      # rebuild from the same checkout would keep the already-rewritten copy
      # and _write_version() would then fail. Removing it forces a fresh copy.
      # See jax-ml/jax#18252.
      if os.path.isfile(stamped_copy):
        os.unlink(stamped_copy)
      super().run()
      _write_version(stamped_copy)

  class _sdist(sdist_orig):
    def make_release_tree(self, base_dir, files):
      super().make_release_tree(base_dir, files)
      if _release_version is None:
        _write_version(os.path.join(base_dir, pkg_source_path,
                                    module_filename))

  return {"sdist": _sdist, "build_py": _build_py}
# Version string of this package, computed once when the module is imported.
__version__ = _get_version_string()
# Oldest jaxlib version this source tree declares it works with.
# NOTE(review): the check that consumes this is not in this file; presumably
# jax compares it against the installed jaxlib at import time. Confirm.
_minimum_jaxlib_version = "0.5.3"
def _version_as_tuple(version_str):
  """Returns the leading dotted-numeric release segment of a version as ints.

  Only the ``digits(.digits)*`` prefix of ``version_str`` is used, so
  pre-release, dev and local-version suffixes are ignored:

    "0.5.4"                    -> (0, 5, 4)
    "0.5.4rc1"                 -> (0, 5, 4)
    "0.5.4.dev20250407+abc123" -> (0, 5, 4)

  Returns an empty tuple if ``version_str`` does not start with a digit.
  """
  # Match only the leading release segment. A component with a suffix
  # attached ("4rc1") must keep its number. Numbers appearing after the
  # release segment (e.g. in a local version label) must not be picked up.
  match = re.match(r"[0-9]+(?:\.[0-9]+)*", version_str)
  if match is None:
    return ()
  return tuple(int(part) for part in match.group(0).split("."))
# Integer-tuple forms of the version strings above, as produced by
# _version_as_tuple (e.g. "0.5.3" -> (0, 5, 3)).
__version_info__ = _version_as_tuple(__version__)
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)