Add the nightly dev version to __version__ of jaxlib.

PiperOrigin-RevId: 448001375
This commit is contained in:
Yash Katariya 2022-05-11 08:34:46 -07:00 committed by jax authors
parent 90f926ac6b
commit 46d034baab
2 changed files with 21 additions and 4 deletions

View File

@ -18,10 +18,13 @@
# Most users should not run this script directly; use build.py instead.
import argparse
import datetime
import functools
import glob
import os
import pathlib
import platform
import re
import shutil
import subprocess
import sys
@ -274,6 +277,22 @@ def prepare_wheel(sources_path):
patch_copy_tpu_client_py(jaxlib_dir)
def edit_jaxlib_version(sources_path):
version_regex = re.compile(r'__version__ = \"(.*)\"')
version_file = pathlib.Path(sources_path) / "jaxlib" / "version.py"
content = version_file.read_text()
version_num = version_regex.search(content).group(1)
datestring = datetime.datetime.now().strftime('%Y%m%d')
nightly_version = f'{version_num}.dev{datestring}'
content = content.replace(f'__version__ = "{version_num}"',
f'__version__ = "{nightly_version}"')
version_file.write_text(content)
def build_wheel(sources_path, output_path, cpu):
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
platform_name, cpu_name = {
@ -288,6 +307,8 @@ def build_wheel(sources_path, output_path, cpu):
f"{sys.version_info.minor}")
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
cwd = os.getcwd()
if os.environ.get('JAXLIB_NIGHTLY'):
edit_jaxlib_version(sources_path)
os.chdir(sources_path)
subprocess.run([sys.executable, "setup.py", "bdist_wheel",
python_tag_arg, platform_tag_arg], check=True)

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from setuptools import setup
import os
@ -30,9 +29,6 @@ if cuda_version and cudnn_version:
nightly = os.environ.get('JAXLIB_NIGHTLY')
if nightly:
project_name = 'jaxlib-nightly'
# Version as `X.Y.Z.dev20220510`
datestring = datetime.datetime.now().strftime('%Y%m%d')
__version__ = f'{__version__}.dev{datestring}'
setup(
name=project_name,